// torsh_hub/cache.rs
//! Cache management for ToRSh Hub

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use torsh_core::error::{Result, TorshError};

10/// Cache manager for hub models and repositories
11#[derive(Debug, Clone)]
12pub struct CacheManager {
13    cache_dir: PathBuf,
14    metadata_file: PathBuf,
15    metadata: CacheMetadata,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19struct CacheMetadata {
20    version: String,
21    repositories: HashMap<String, RepoMetadata>,
22    models: HashMap<String, ModelMetadata>,
23    /// Total number of cache lookups
24    #[serde(default)]
25    total_lookups: u64,
26    /// Number of successful cache hits
27    #[serde(default)]
28    cache_hits: u64,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32struct RepoMetadata {
33    owner: String,
34    name: String,
35    branch: String,
36    last_updated: chrono::DateTime<chrono::Utc>,
37    size_bytes: u64,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41struct ModelMetadata {
42    repo: String,
43    name: String,
44    version: String,
45    hash: String,
46    size_bytes: u64,
47    last_accessed: chrono::DateTime<chrono::Utc>,
48}
49
50impl CacheManager {
51    /// Create a new cache manager
52    pub fn new(cache_dir: &Path) -> Result<Self> {
53        fs::create_dir_all(cache_dir)?;
54
55        let metadata_file = cache_dir.join("cache_metadata.json");
56        let metadata = if metadata_file.exists() {
57            load_metadata(&metadata_file)?
58        } else {
59            CacheMetadata {
60                version: "1.0".to_string(),
61                repositories: HashMap::new(),
62                models: HashMap::new(),
63                total_lookups: 0,
64                cache_hits: 0,
65            }
66        };
67
68        Ok(Self {
69            cache_dir: cache_dir.to_path_buf(),
70            metadata_file,
71            metadata,
72        })
73    }
74
75    /// Get repository directory
76    pub fn get_repo_dir(&self, owner: &str, repo: &str, branch: &str) -> PathBuf {
77        self.cache_dir
78            .join("repositories")
79            .join(format!("{}__{}__{}", owner, repo, branch))
80    }
81
82    /// Get model cache path
83    pub fn get_model_path(&self, repo: &str, model: &str, version: &str) -> PathBuf {
84        self.cache_dir
85            .join("models")
86            .join(repo.replace('/', "__"))
87            .join(format!("{}_{}.torsh", model, version))
88    }
89
90    /// Check if repository is cached
91    pub fn is_repo_cached(&mut self, owner: &str, repo: &str, branch: &str) -> bool {
92        let key = format!("{}/{}/{}", owner, repo, branch);
93        self.metadata.total_lookups += 1;
94        let is_cached = self.metadata.repositories.contains_key(&key);
95        if is_cached {
96            self.metadata.cache_hits += 1;
97        }
98        is_cached
99    }
100
101    /// Check if model is cached
102    pub fn is_model_cached(&mut self, repo: &str, model: &str, version: &str) -> bool {
103        let key = format!("{}/{}@{}", repo, model, version);
104        self.metadata.total_lookups += 1;
105        let is_cached = self.metadata.models.contains_key(&key);
106        if is_cached {
107            self.metadata.cache_hits += 1;
108        }
109        is_cached
110    }
111
112    /// Add repository to cache
113    pub fn add_repo(
114        &mut self,
115        owner: &str,
116        repo: &str,
117        branch: &str,
118        size_bytes: u64,
119    ) -> Result<()> {
120        let key = format!("{}/{}/{}", owner, repo, branch);
121        self.metadata.repositories.insert(
122            key,
123            RepoMetadata {
124                owner: owner.to_string(),
125                name: repo.to_string(),
126                branch: branch.to_string(),
127                last_updated: chrono::Utc::now(),
128                size_bytes,
129            },
130        );
131
132        self.save_metadata()
133    }
134
135    /// Add model to cache
136    pub fn add_model(
137        &mut self,
138        repo: &str,
139        model: &str,
140        version: &str,
141        hash: &str,
142        size_bytes: u64,
143    ) -> Result<()> {
144        let key = format!("{}/{}@{}", repo, model, version);
145        self.metadata.models.insert(
146            key,
147            ModelMetadata {
148                repo: repo.to_string(),
149                name: model.to_string(),
150                version: version.to_string(),
151                hash: hash.to_string(),
152                size_bytes,
153                last_accessed: chrono::Utc::now(),
154            },
155        );
156
157        self.save_metadata()
158    }
159
160    /// Update model access time
161    pub fn touch_model(&mut self, repo: &str, model: &str, version: &str) -> Result<()> {
162        let key = format!("{}/{}@{}", repo, model, version);
163        if let Some(metadata) = self.metadata.models.get_mut(&key) {
164            metadata.last_accessed = chrono::Utc::now();
165            self.save_metadata()?;
166        }
167        Ok(())
168    }
169
170    /// Get cache statistics
171    pub fn get_cache_stats(&self) -> CacheStats {
172        let total_repos = self.metadata.repositories.len();
173        let total_models = self.metadata.models.len();
174        let total_size = self
175            .metadata
176            .repositories
177            .values()
178            .map(|r| r.size_bytes)
179            .sum::<u64>()
180            + self
181                .metadata
182                .models
183                .values()
184                .map(|m| m.size_bytes)
185                .sum::<u64>();
186
187        let oldest_access = self
188            .metadata
189            .models
190            .values()
191            .map(|m| m.last_accessed)
192            .min()
193            .unwrap_or_else(chrono::Utc::now);
194
195        let newest_access = self
196            .metadata
197            .models
198            .values()
199            .map(|m| m.last_accessed)
200            .max()
201            .unwrap_or_else(chrono::Utc::now);
202
203        CacheStats {
204            total_repositories: total_repos,
205            total_models,
206            total_size_bytes: total_size,
207            total_size_formatted: format_bytes(total_size),
208            oldest_access,
209            newest_access,
210            hit_rate: self.calculate_hit_rate(),
211        }
212    }
213
214    /// Perform cache cleanup based on size and age limits
215    pub fn cleanup_cache(
216        &mut self,
217        max_size_bytes: u64,
218        max_age_days: u32,
219    ) -> Result<CacheCleanupResult> {
220        let mut cleanup_result = CacheCleanupResult::default();
221        let cutoff_date = chrono::Utc::now() - chrono::Duration::days(max_age_days as i64);
222
223        // First pass: Remove old models
224        let mut models_to_remove = Vec::new();
225        for (key, metadata) in &self.metadata.models {
226            if metadata.last_accessed < cutoff_date {
227                models_to_remove.push(key.clone());
228            }
229        }
230
231        for key in models_to_remove {
232            if let Some(metadata) = self.metadata.models.remove(&key) {
233                let model_path =
234                    self.get_model_path(&metadata.repo, &metadata.name, &metadata.version);
235                if model_path.exists() {
236                    fs::remove_file(&model_path)?;
237                    cleanup_result.freed_bytes += metadata.size_bytes;
238                    cleanup_result.models_removed += 1;
239                }
240            }
241        }
242
243        // Second pass: Remove models by size if still over limit
244        let current_size = self.get_cache_stats().total_size_bytes;
245        if current_size > max_size_bytes {
246            let size_to_free = current_size - max_size_bytes;
247            let mut candidates: Vec<_> = self.metadata.models.iter().collect();
248            candidates.sort_by_key(|(_, metadata)| metadata.last_accessed);
249
250            let mut freed_bytes = 0u64;
251            let mut models_to_remove = Vec::new();
252
253            for (key, metadata) in candidates {
254                if freed_bytes >= size_to_free {
255                    break;
256                }
257                models_to_remove.push(key.clone());
258                freed_bytes += metadata.size_bytes;
259            }
260
261            for key in models_to_remove {
262                if let Some(metadata) = self.metadata.models.remove(&key) {
263                    let model_path =
264                        self.get_model_path(&metadata.repo, &metadata.name, &metadata.version);
265                    if model_path.exists() {
266                        fs::remove_file(&model_path)?;
267                        cleanup_result.freed_bytes += metadata.size_bytes;
268                        cleanup_result.models_removed += 1;
269                    }
270                }
271            }
272        }
273
274        self.save_metadata()?;
275        cleanup_result.cleanup_duration = std::time::Instant::now().elapsed();
276
277        Ok(cleanup_result)
278    }
279
280    /// Compress cache files to save disk space
281    pub fn compress_cache(&mut self) -> Result<CompressionResult> {
282        let mut compression_result = CompressionResult::default();
283        let cache_models_dir = self.cache_dir.join("models");
284
285        if !cache_models_dir.exists() {
286            return Ok(compression_result);
287        }
288
289        // First collect the keys and metadata info to avoid borrowing conflicts
290        let models_info: Vec<(String, String, String, String)> = self
291            .metadata
292            .models
293            .iter()
294            .map(|(key, metadata)| {
295                (
296                    key.clone(),
297                    metadata.repo.clone(),
298                    metadata.name.clone(),
299                    metadata.version.clone(),
300                )
301            })
302            .collect();
303
304        for (key, repo, name, version) in models_info {
305            let model_path = self.get_model_path(&repo, &name, &version);
306            let metadata = self
307                .metadata
308                .models
309                .get_mut(&key)
310                .expect("model metadata should exist in cache");
311            let compressed_path = model_path.with_extension("torsh.gz");
312
313            if model_path.exists() && !compressed_path.exists() {
314                match compress_file(&model_path, &compressed_path) {
315                    Ok(compression_stats) => {
316                        // Update metadata to point to compressed file
317                        let old_size = metadata.size_bytes;
318                        metadata.size_bytes = compression_stats.compressed_size;
319
320                        // Remove original file
321                        fs::remove_file(&model_path)?;
322
323                        compression_result.files_compressed += 1;
324                        compression_result.original_bytes += compression_stats.original_size;
325                        compression_result.compressed_bytes += compression_stats.compressed_size;
326                        compression_result.bytes_saved +=
327                            old_size - compression_stats.compressed_size;
328                    }
329                    Err(e) => {
330                        eprintln!("Failed to compress {}: {}", model_path.display(), e);
331                        compression_result.compression_failures += 1;
332                    }
333                }
334            }
335        }
336
337        self.save_metadata()?;
338        Ok(compression_result)
339    }
340
341    /// Calculate cache hit rate
342    fn calculate_hit_rate(&self) -> f32 {
343        if self.metadata.total_lookups == 0 {
344            0.0
345        } else {
346            self.metadata.cache_hits as f32 / self.metadata.total_lookups as f32
347        }
348    }
349
350    /// Validate cache integrity
351    pub fn validate_cache(&self) -> Result<CacheValidationResult> {
352        let mut validation_result = CacheValidationResult::default();
353
354        // Check model files
355        for (key, metadata) in &self.metadata.models {
356            let model_path = self.get_model_path(&metadata.repo, &metadata.name, &metadata.version);
357            let compressed_path = model_path.with_extension("torsh.gz");
358
359            if model_path.exists() {
360                validation_result.valid_models += 1;
361            } else if compressed_path.exists() {
362                validation_result.valid_models += 1;
363            } else {
364                validation_result.missing_files.push(key.clone());
365                validation_result.invalid_models += 1;
366            }
367        }
368
369        // Check repository directories
370        for (key, metadata) in &self.metadata.repositories {
371            let repo_dir = self.get_repo_dir(&metadata.owner, &metadata.name, &metadata.branch);
372            if repo_dir.exists() {
373                validation_result.valid_repositories += 1;
374            } else {
375                validation_result.missing_directories.push(key.clone());
376                validation_result.invalid_repositories += 1;
377            }
378        }
379
380        Ok(validation_result)
381    }
382
383    /// Get cache size
384    pub fn get_cache_size(&self) -> u64 {
385        let repo_size: u64 = self
386            .metadata
387            .repositories
388            .values()
389            .map(|r| r.size_bytes)
390            .sum();
391
392        let model_size: u64 = self.metadata.models.values().map(|m| m.size_bytes).sum();
393
394        repo_size + model_size
395    }
396
397    /// Clean old cache entries
398    pub fn clean_cache(
399        &mut self,
400        max_size_bytes: Option<u64>,
401        max_age_days: Option<u64>,
402    ) -> Result<()> {
403        let now = chrono::Utc::now();
404
405        // Remove old models
406        if let Some(max_age) = max_age_days {
407            let cutoff = now - chrono::Duration::days(max_age as i64);
408
409            let old_models: Vec<String> = self
410                .metadata
411                .models
412                .iter()
413                .filter(|(_, m)| m.last_accessed < cutoff)
414                .map(|(k, _)| k.clone())
415                .collect();
416
417            for key in old_models {
418                if let Some(model) = self.metadata.models.remove(&key) {
419                    let path = self.get_model_path(&model.repo, &model.name, &model.version);
420                    let _ = fs::remove_file(path);
421                }
422            }
423        }
424
425        // Remove to fit size limit
426        if let Some(max_size) = max_size_bytes {
427            while self.get_cache_size() > max_size {
428                // Find least recently used model
429                let lru_key = self
430                    .metadata
431                    .models
432                    .iter()
433                    .min_by_key(|(_, m)| m.last_accessed)
434                    .map(|(k, _)| k.clone());
435
436                if let Some(key) = lru_key {
437                    if let Some(model) = self.metadata.models.remove(&key) {
438                        let path = self.get_model_path(&model.repo, &model.name, &model.version);
439                        let _ = fs::remove_file(path);
440                    }
441                } else {
442                    break;
443                }
444            }
445        }
446
447        self.save_metadata()
448    }
449
450    /// Clear entire cache
451    pub fn clear_cache(&mut self) -> Result<()> {
452        // Remove all files
453        if self.cache_dir.exists() {
454            fs::remove_dir_all(&self.cache_dir)?;
455            fs::create_dir_all(&self.cache_dir)?;
456        }
457
458        // Reset metadata
459        self.metadata.repositories.clear();
460        self.metadata.models.clear();
461
462        self.save_metadata()
463    }
464
465    /// Save metadata to disk
466    fn save_metadata(&self) -> Result<()> {
467        let json = serde_json::to_string_pretty(&self.metadata)
468            .map_err(|e| TorshError::SerializationError(e.to_string()))?;
469        let mut file = File::create(&self.metadata_file)?;
470        file.write_all(json.as_bytes())?;
471        Ok(())
472    }
473}
474
475/// Load metadata from file
476fn load_metadata(path: &Path) -> Result<CacheMetadata> {
477    let mut file = File::open(path)?;
478    let mut content = String::new();
479    file.read_to_string(&mut content)?;
480
481    serde_json::from_str(&content).map_err(|e| TorshError::SerializationError(e.to_string()))
482}
483
484/// Get directory size recursively
485pub fn get_dir_size(path: &Path) -> Result<u64> {
486    let mut size = 0;
487
488    if path.is_dir() {
489        for entry in fs::read_dir(path)? {
490            let entry = entry?;
491            let path = entry.path();
492
493            if path.is_dir() {
494                size += get_dir_size(&path)?;
495            } else {
496                size += entry.metadata()?.len();
497            }
498        }
499    } else if path.is_file() {
500        size = fs::metadata(path)?.len();
501    }
502
503    Ok(size)
504}
505
506/// Verify file hash
507pub fn verify_file_hash(path: &Path, expected_hash: &str) -> Result<bool> {
508    use sha2::{Digest, Sha256};
509
510    let mut file = File::open(path)?;
511    let mut hasher = Sha256::new();
512    let mut buffer = [0; 8192];
513
514    loop {
515        let n = file.read(&mut buffer)?;
516        if n == 0 {
517            break;
518        }
519        hasher.update(&buffer[..n]);
520    }
521
522    let result = hasher.finalize();
523    let actual_hash = hex::encode(result);
524
525    Ok(actual_hash == expected_hash)
526}
527
528/// Cache statistics structure
529#[derive(Debug, Clone, Serialize, Deserialize)]
530pub struct CacheStats {
531    pub total_repositories: usize,
532    pub total_models: usize,
533    pub total_size_bytes: u64,
534    pub total_size_formatted: String,
535    pub oldest_access: chrono::DateTime<chrono::Utc>,
536    pub newest_access: chrono::DateTime<chrono::Utc>,
537    pub hit_rate: f32,
538}
539
/// Outcome of a `cleanup_cache` run.
#[derive(Debug, Clone, Default)]
pub struct CacheCleanupResult {
    // Model entries removed (file deleted and metadata dropped).
    pub models_removed: usize,
    // Repository entries removed.
    pub repositories_removed: usize,
    // Total recorded bytes freed.
    pub freed_bytes: u64,
    // Wall-clock time the cleanup took.
    pub cleanup_duration: std::time::Duration,
}

/// Outcome of a `compress_cache` run.
#[derive(Debug, Clone, Default)]
pub struct CompressionResult {
    // Files successfully gzip-compressed.
    pub files_compressed: usize,
    // Total size of the inputs before compression.
    pub original_bytes: u64,
    // Total size of the gzip outputs.
    pub compressed_bytes: u64,
    // Net disk space reclaimed.
    pub bytes_saved: u64,
    // Files that failed to compress (counted, not fatal).
    pub compression_failures: usize,
}

/// Outcome of a `validate_cache` integrity check.
#[derive(Debug, Clone, Default)]
pub struct CacheValidationResult {
    // Models whose backing file (plain or compressed) exists.
    pub valid_models: usize,
    // Models with metadata but no file on disk.
    pub invalid_models: usize,
    // Repositories whose checkout directory exists.
    pub valid_repositories: usize,
    // Repositories with metadata but no directory on disk.
    pub invalid_repositories: usize,
    // Cache keys of models with missing files.
    pub missing_files: Vec<String>,
    // Cache keys of repositories with missing directories.
    pub missing_directories: Vec<String>,
}

/// Size statistics for a single file compression.
#[derive(Debug, Clone)]
pub struct FileCompressionStats {
    // Input size in bytes.
    pub original_size: u64,
    // Output size in bytes.
    pub compressed_size: u64,
    // compressed / original; 1.0 for an empty input.
    pub compression_ratio: f32,
}

/// Render a byte count as a human-readable string (e.g. "1.5 KB").
///
/// Steps through binary (1024-based) units B, KB, MB, GB, TB. Plain byte
/// counts are printed exactly; larger units get one decimal place.
pub fn format_bytes(bytes: u64) -> String {
    const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB"];
    const STEP: f64 = 1024.0;

    if bytes == 0 {
        return "0 B".to_string();
    }

    let mut value = bytes as f64;
    let mut unit = 0usize;
    // Scale down until the value fits the unit, capping at the largest unit.
    while value >= STEP && unit + 1 < UNITS.len() {
        value /= STEP;
        unit += 1;
    }

    if unit == 0 {
        // Sub-kilobyte counts are exact integers.
        format!("{} {}", bytes, UNITS[0])
    } else {
        format!("{:.1} {}", value, UNITS[unit])
    }
}

602/// Compress a file using gzip compression
603pub fn compress_file(input_path: &Path, output_path: &Path) -> Result<FileCompressionStats> {
604    use flate2::write::GzEncoder;
605    use flate2::Compression;
606    use std::io::copy;
607
608    let input_file = File::open(input_path)?;
609    let output_file = File::create(output_path)?;
610    let mut encoder = GzEncoder::new(output_file, Compression::default());
611
612    let original_size = input_file.metadata()?.len();
613
614    let mut input_reader = std::io::BufReader::new(input_file);
615    copy(&mut input_reader, &mut encoder)?;
616    encoder.finish()?;
617
618    let compressed_size = fs::metadata(output_path)?.len();
619    let compression_ratio = if original_size > 0 {
620        compressed_size as f32 / original_size as f32
621    } else {
622        1.0
623    };
624
625    Ok(FileCompressionStats {
626        original_size,
627        compressed_size,
628        compression_ratio,
629    })
630}
631
632/// Decompress a gzip file
633pub fn decompress_file(input_path: &Path, output_path: &Path) -> Result<()> {
634    use flate2::read::GzDecoder;
635    use std::io::copy;
636
637    let input_file = File::open(input_path)?;
638    let mut decoder = GzDecoder::new(input_file);
639    let mut output_file = File::create(output_path)?;
640
641    copy(&mut decoder, &mut output_file)?;
642    Ok(())
643}
644
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_cache_manager() {
        let dir = TempDir::new().unwrap();
        let mut manager = CacheManager::new(dir.path()).unwrap();

        // Repository caching round-trip.
        assert!(!manager.is_repo_cached("owner", "repo", "main"));
        manager.add_repo("owner", "repo", "main", 1000).unwrap();
        assert!(manager.is_repo_cached("owner", "repo", "main"));

        // Model caching round-trip.
        assert!(!manager.is_model_cached("owner/repo", "model", "v1.0"));
        manager
            .add_model("owner/repo", "model", "v1.0", "hash123", 2000)
            .unwrap();
        assert!(manager.is_model_cached("owner/repo", "model", "v1.0"));

        // Combined recorded size: repo (1000) + model (2000).
        assert_eq!(manager.get_cache_size(), 3000);
    }

    #[test]
    fn test_cache_cleaning() {
        let dir = TempDir::new().unwrap();
        let mut manager = CacheManager::new(dir.path()).unwrap();

        manager
            .add_model("repo1", "model1", "v1", "hash1", 1000)
            .unwrap();

        // Ensure the two models get distinguishable timestamps.
        std::thread::sleep(std::time::Duration::from_millis(10));

        manager
            .add_model("repo2", "model2", "v1", "hash2", 2000)
            .unwrap();

        // Re-touch the first model so model2 becomes the LRU candidate.
        manager.touch_model("repo1", "model1", "v1").unwrap();

        // Size-limited clean should evict model2 (LRU) and keep model1.
        manager.clean_cache(Some(1500), None).unwrap();

        assert_eq!(manager.metadata.models.len(), 1);
        assert!(manager.is_model_cached("repo1", "model1", "v1"));
    }
}