1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs::{self, File};
6use std::io::{Read, Write};
7use std::path::{Path, PathBuf};
8use torsh_core::error::{Result, TorshError};
9
/// Manages Torsh's on-disk cache of cloned repositories and downloaded model
/// files, backed by a JSON metadata index stored inside the cache directory.
#[derive(Debug, Clone)]
pub struct CacheManager {
    /// Root directory under which `repositories/` and `models/` live.
    cache_dir: PathBuf,
    /// Path to `cache_metadata.json` inside `cache_dir`.
    metadata_file: PathBuf,
    /// In-memory copy of the persisted cache index.
    metadata: CacheMetadata,
}
17
/// Persisted cache index, serialized to `cache_metadata.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CacheMetadata {
    /// Metadata schema version ("1.0" for newly created caches).
    version: String,
    /// Cached repositories keyed by `"owner/repo/branch"`.
    repositories: HashMap<String, RepoMetadata>,
    /// Cached models keyed by `"repo/model@version"`.
    models: HashMap<String, ModelMetadata>,
    /// Total number of cache lookups; `#[serde(default)]` keeps older
    /// metadata files (written before this field existed) loadable.
    #[serde(default)]
    total_lookups: u64,
    /// Number of lookups that found an entry; defaults to 0 for the same
    /// backward-compatibility reason as `total_lookups`.
    #[serde(default)]
    cache_hits: u64,
}
30
/// Metadata recorded for a single cached repository checkout.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RepoMetadata {
    /// Repository owner (user or organization).
    owner: String,
    /// Repository name.
    name: String,
    /// Branch that was fetched.
    branch: String,
    /// When this entry was last written by `add_repo`.
    last_updated: chrono::DateTime<chrono::Utc>,
    /// Size reported by the caller at insertion time, in bytes.
    size_bytes: u64,
}
39
/// Metadata recorded for a single cached model file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ModelMetadata {
    /// Repository the model came from (e.g. "owner/repo").
    repo: String,
    /// Model name.
    name: String,
    /// Model version string.
    version: String,
    /// Content hash supplied by the caller at insertion time.
    hash: String,
    /// Size in bytes; updated to the compressed size by `compress_cache`.
    size_bytes: u64,
    /// Last access time; drives LRU eviction and age-based cleanup.
    last_accessed: chrono::DateTime<chrono::Utc>,
}
49
50impl CacheManager {
51 pub fn new(cache_dir: &Path) -> Result<Self> {
53 fs::create_dir_all(cache_dir)?;
54
55 let metadata_file = cache_dir.join("cache_metadata.json");
56 let metadata = if metadata_file.exists() {
57 load_metadata(&metadata_file)?
58 } else {
59 CacheMetadata {
60 version: "1.0".to_string(),
61 repositories: HashMap::new(),
62 models: HashMap::new(),
63 total_lookups: 0,
64 cache_hits: 0,
65 }
66 };
67
68 Ok(Self {
69 cache_dir: cache_dir.to_path_buf(),
70 metadata_file,
71 metadata,
72 })
73 }
74
75 pub fn get_repo_dir(&self, owner: &str, repo: &str, branch: &str) -> PathBuf {
77 self.cache_dir
78 .join("repositories")
79 .join(format!("{}__{}__{}", owner, repo, branch))
80 }
81
82 pub fn get_model_path(&self, repo: &str, model: &str, version: &str) -> PathBuf {
84 self.cache_dir
85 .join("models")
86 .join(repo.replace('/', "__"))
87 .join(format!("{}_{}.torsh", model, version))
88 }
89
90 pub fn is_repo_cached(&mut self, owner: &str, repo: &str, branch: &str) -> bool {
92 let key = format!("{}/{}/{}", owner, repo, branch);
93 self.metadata.total_lookups += 1;
94 let is_cached = self.metadata.repositories.contains_key(&key);
95 if is_cached {
96 self.metadata.cache_hits += 1;
97 }
98 is_cached
99 }
100
101 pub fn is_model_cached(&mut self, repo: &str, model: &str, version: &str) -> bool {
103 let key = format!("{}/{}@{}", repo, model, version);
104 self.metadata.total_lookups += 1;
105 let is_cached = self.metadata.models.contains_key(&key);
106 if is_cached {
107 self.metadata.cache_hits += 1;
108 }
109 is_cached
110 }
111
112 pub fn add_repo(
114 &mut self,
115 owner: &str,
116 repo: &str,
117 branch: &str,
118 size_bytes: u64,
119 ) -> Result<()> {
120 let key = format!("{}/{}/{}", owner, repo, branch);
121 self.metadata.repositories.insert(
122 key,
123 RepoMetadata {
124 owner: owner.to_string(),
125 name: repo.to_string(),
126 branch: branch.to_string(),
127 last_updated: chrono::Utc::now(),
128 size_bytes,
129 },
130 );
131
132 self.save_metadata()
133 }
134
135 pub fn add_model(
137 &mut self,
138 repo: &str,
139 model: &str,
140 version: &str,
141 hash: &str,
142 size_bytes: u64,
143 ) -> Result<()> {
144 let key = format!("{}/{}@{}", repo, model, version);
145 self.metadata.models.insert(
146 key,
147 ModelMetadata {
148 repo: repo.to_string(),
149 name: model.to_string(),
150 version: version.to_string(),
151 hash: hash.to_string(),
152 size_bytes,
153 last_accessed: chrono::Utc::now(),
154 },
155 );
156
157 self.save_metadata()
158 }
159
160 pub fn touch_model(&mut self, repo: &str, model: &str, version: &str) -> Result<()> {
162 let key = format!("{}/{}@{}", repo, model, version);
163 if let Some(metadata) = self.metadata.models.get_mut(&key) {
164 metadata.last_accessed = chrono::Utc::now();
165 self.save_metadata()?;
166 }
167 Ok(())
168 }
169
170 pub fn get_cache_stats(&self) -> CacheStats {
172 let total_repos = self.metadata.repositories.len();
173 let total_models = self.metadata.models.len();
174 let total_size = self
175 .metadata
176 .repositories
177 .values()
178 .map(|r| r.size_bytes)
179 .sum::<u64>()
180 + self
181 .metadata
182 .models
183 .values()
184 .map(|m| m.size_bytes)
185 .sum::<u64>();
186
187 let oldest_access = self
188 .metadata
189 .models
190 .values()
191 .map(|m| m.last_accessed)
192 .min()
193 .unwrap_or_else(chrono::Utc::now);
194
195 let newest_access = self
196 .metadata
197 .models
198 .values()
199 .map(|m| m.last_accessed)
200 .max()
201 .unwrap_or_else(chrono::Utc::now);
202
203 CacheStats {
204 total_repositories: total_repos,
205 total_models,
206 total_size_bytes: total_size,
207 total_size_formatted: format_bytes(total_size),
208 oldest_access,
209 newest_access,
210 hit_rate: self.calculate_hit_rate(),
211 }
212 }
213
214 pub fn cleanup_cache(
216 &mut self,
217 max_size_bytes: u64,
218 max_age_days: u32,
219 ) -> Result<CacheCleanupResult> {
220 let mut cleanup_result = CacheCleanupResult::default();
221 let cutoff_date = chrono::Utc::now() - chrono::Duration::days(max_age_days as i64);
222
223 let mut models_to_remove = Vec::new();
225 for (key, metadata) in &self.metadata.models {
226 if metadata.last_accessed < cutoff_date {
227 models_to_remove.push(key.clone());
228 }
229 }
230
231 for key in models_to_remove {
232 if let Some(metadata) = self.metadata.models.remove(&key) {
233 let model_path =
234 self.get_model_path(&metadata.repo, &metadata.name, &metadata.version);
235 if model_path.exists() {
236 fs::remove_file(&model_path)?;
237 cleanup_result.freed_bytes += metadata.size_bytes;
238 cleanup_result.models_removed += 1;
239 }
240 }
241 }
242
243 let current_size = self.get_cache_stats().total_size_bytes;
245 if current_size > max_size_bytes {
246 let size_to_free = current_size - max_size_bytes;
247 let mut candidates: Vec<_> = self.metadata.models.iter().collect();
248 candidates.sort_by_key(|(_, metadata)| metadata.last_accessed);
249
250 let mut freed_bytes = 0u64;
251 let mut models_to_remove = Vec::new();
252
253 for (key, metadata) in candidates {
254 if freed_bytes >= size_to_free {
255 break;
256 }
257 models_to_remove.push(key.clone());
258 freed_bytes += metadata.size_bytes;
259 }
260
261 for key in models_to_remove {
262 if let Some(metadata) = self.metadata.models.remove(&key) {
263 let model_path =
264 self.get_model_path(&metadata.repo, &metadata.name, &metadata.version);
265 if model_path.exists() {
266 fs::remove_file(&model_path)?;
267 cleanup_result.freed_bytes += metadata.size_bytes;
268 cleanup_result.models_removed += 1;
269 }
270 }
271 }
272 }
273
274 self.save_metadata()?;
275 cleanup_result.cleanup_duration = std::time::Instant::now().elapsed();
276
277 Ok(cleanup_result)
278 }
279
280 pub fn compress_cache(&mut self) -> Result<CompressionResult> {
282 let mut compression_result = CompressionResult::default();
283 let cache_models_dir = self.cache_dir.join("models");
284
285 if !cache_models_dir.exists() {
286 return Ok(compression_result);
287 }
288
289 let models_info: Vec<(String, String, String, String)> = self
291 .metadata
292 .models
293 .iter()
294 .map(|(key, metadata)| {
295 (
296 key.clone(),
297 metadata.repo.clone(),
298 metadata.name.clone(),
299 metadata.version.clone(),
300 )
301 })
302 .collect();
303
304 for (key, repo, name, version) in models_info {
305 let model_path = self.get_model_path(&repo, &name, &version);
306 let metadata = self
307 .metadata
308 .models
309 .get_mut(&key)
310 .expect("model metadata should exist in cache");
311 let compressed_path = model_path.with_extension("torsh.gz");
312
313 if model_path.exists() && !compressed_path.exists() {
314 match compress_file(&model_path, &compressed_path) {
315 Ok(compression_stats) => {
316 let old_size = metadata.size_bytes;
318 metadata.size_bytes = compression_stats.compressed_size;
319
320 fs::remove_file(&model_path)?;
322
323 compression_result.files_compressed += 1;
324 compression_result.original_bytes += compression_stats.original_size;
325 compression_result.compressed_bytes += compression_stats.compressed_size;
326 compression_result.bytes_saved +=
327 old_size - compression_stats.compressed_size;
328 }
329 Err(e) => {
330 eprintln!("Failed to compress {}: {}", model_path.display(), e);
331 compression_result.compression_failures += 1;
332 }
333 }
334 }
335 }
336
337 self.save_metadata()?;
338 Ok(compression_result)
339 }
340
341 fn calculate_hit_rate(&self) -> f32 {
343 if self.metadata.total_lookups == 0 {
344 0.0
345 } else {
346 self.metadata.cache_hits as f32 / self.metadata.total_lookups as f32
347 }
348 }
349
350 pub fn validate_cache(&self) -> Result<CacheValidationResult> {
352 let mut validation_result = CacheValidationResult::default();
353
354 for (key, metadata) in &self.metadata.models {
356 let model_path = self.get_model_path(&metadata.repo, &metadata.name, &metadata.version);
357 let compressed_path = model_path.with_extension("torsh.gz");
358
359 if model_path.exists() {
360 validation_result.valid_models += 1;
361 } else if compressed_path.exists() {
362 validation_result.valid_models += 1;
363 } else {
364 validation_result.missing_files.push(key.clone());
365 validation_result.invalid_models += 1;
366 }
367 }
368
369 for (key, metadata) in &self.metadata.repositories {
371 let repo_dir = self.get_repo_dir(&metadata.owner, &metadata.name, &metadata.branch);
372 if repo_dir.exists() {
373 validation_result.valid_repositories += 1;
374 } else {
375 validation_result.missing_directories.push(key.clone());
376 validation_result.invalid_repositories += 1;
377 }
378 }
379
380 Ok(validation_result)
381 }
382
383 pub fn get_cache_size(&self) -> u64 {
385 let repo_size: u64 = self
386 .metadata
387 .repositories
388 .values()
389 .map(|r| r.size_bytes)
390 .sum();
391
392 let model_size: u64 = self.metadata.models.values().map(|m| m.size_bytes).sum();
393
394 repo_size + model_size
395 }
396
397 pub fn clean_cache(
399 &mut self,
400 max_size_bytes: Option<u64>,
401 max_age_days: Option<u64>,
402 ) -> Result<()> {
403 let now = chrono::Utc::now();
404
405 if let Some(max_age) = max_age_days {
407 let cutoff = now - chrono::Duration::days(max_age as i64);
408
409 let old_models: Vec<String> = self
410 .metadata
411 .models
412 .iter()
413 .filter(|(_, m)| m.last_accessed < cutoff)
414 .map(|(k, _)| k.clone())
415 .collect();
416
417 for key in old_models {
418 if let Some(model) = self.metadata.models.remove(&key) {
419 let path = self.get_model_path(&model.repo, &model.name, &model.version);
420 let _ = fs::remove_file(path);
421 }
422 }
423 }
424
425 if let Some(max_size) = max_size_bytes {
427 while self.get_cache_size() > max_size {
428 let lru_key = self
430 .metadata
431 .models
432 .iter()
433 .min_by_key(|(_, m)| m.last_accessed)
434 .map(|(k, _)| k.clone());
435
436 if let Some(key) = lru_key {
437 if let Some(model) = self.metadata.models.remove(&key) {
438 let path = self.get_model_path(&model.repo, &model.name, &model.version);
439 let _ = fs::remove_file(path);
440 }
441 } else {
442 break;
443 }
444 }
445 }
446
447 self.save_metadata()
448 }
449
450 pub fn clear_cache(&mut self) -> Result<()> {
452 if self.cache_dir.exists() {
454 fs::remove_dir_all(&self.cache_dir)?;
455 fs::create_dir_all(&self.cache_dir)?;
456 }
457
458 self.metadata.repositories.clear();
460 self.metadata.models.clear();
461
462 self.save_metadata()
463 }
464
465 fn save_metadata(&self) -> Result<()> {
467 let json = serde_json::to_string_pretty(&self.metadata)
468 .map_err(|e| TorshError::SerializationError(e.to_string()))?;
469 let mut file = File::create(&self.metadata_file)?;
470 file.write_all(json.as_bytes())?;
471 Ok(())
472 }
473}
474
475fn load_metadata(path: &Path) -> Result<CacheMetadata> {
477 let mut file = File::open(path)?;
478 let mut content = String::new();
479 file.read_to_string(&mut content)?;
480
481 serde_json::from_str(&content).map_err(|e| TorshError::SerializationError(e.to_string()))
482}
483
484pub fn get_dir_size(path: &Path) -> Result<u64> {
486 let mut size = 0;
487
488 if path.is_dir() {
489 for entry in fs::read_dir(path)? {
490 let entry = entry?;
491 let path = entry.path();
492
493 if path.is_dir() {
494 size += get_dir_size(&path)?;
495 } else {
496 size += entry.metadata()?.len();
497 }
498 }
499 } else if path.is_file() {
500 size = fs::metadata(path)?.len();
501 }
502
503 Ok(size)
504}
505
506pub fn verify_file_hash(path: &Path, expected_hash: &str) -> Result<bool> {
508 use sha2::{Digest, Sha256};
509
510 let mut file = File::open(path)?;
511 let mut hasher = Sha256::new();
512 let mut buffer = [0; 8192];
513
514 loop {
515 let n = file.read(&mut buffer)?;
516 if n == 0 {
517 break;
518 }
519 hasher.update(&buffer[..n]);
520 }
521
522 let result = hasher.finalize();
523 let actual_hash = hex::encode(result);
524
525 Ok(actual_hash == expected_hash)
526}
527
/// Snapshot of cache contents produced by `CacheManager::get_cache_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of repository entries in the metadata index.
    pub total_repositories: usize,
    /// Number of model entries in the metadata index.
    pub total_models: usize,
    /// Sum of recorded repository and model sizes, in bytes.
    pub total_size_bytes: u64,
    /// `total_size_bytes` rendered human-readably via `format_bytes`.
    pub total_size_formatted: String,
    /// Earliest `last_accessed` among models ("now" when there are none).
    pub oldest_access: chrono::DateTime<chrono::Utc>,
    /// Latest `last_accessed` among models ("now" when there are none).
    pub newest_access: chrono::DateTime<chrono::Utc>,
    /// cache_hits / total_lookups, or 0.0 before any lookup.
    pub hit_rate: f32,
}
539
/// Outcome of `CacheManager::cleanup_cache`.
#[derive(Debug, Clone, Default)]
pub struct CacheCleanupResult {
    /// Number of model entries whose on-disk files were removed.
    pub models_removed: usize,
    /// NOTE(review): never incremented by cleanup_cache in this file —
    /// repository eviction appears unimplemented; confirm before relying
    /// on this field.
    pub repositories_removed: usize,
    /// Bytes reclaimed, per the sizes recorded in metadata.
    pub freed_bytes: u64,
    /// Wall-clock time the cleanup took.
    pub cleanup_duration: std::time::Duration,
}
548
/// Outcome of `CacheManager::compress_cache`.
#[derive(Debug, Clone, Default)]
pub struct CompressionResult {
    /// Number of model files successfully gzip-compressed.
    pub files_compressed: usize,
    /// Total pre-compression bytes across compressed files.
    pub original_bytes: u64,
    /// Total post-compression bytes across compressed files.
    pub compressed_bytes: u64,
    /// Bytes saved relative to the sizes previously recorded in metadata.
    pub bytes_saved: u64,
    /// Number of files whose compression failed (failures are non-fatal).
    pub compression_failures: usize,
}
558
/// Outcome of `CacheManager::validate_cache`: metadata entries checked
/// against the filesystem.
#[derive(Debug, Clone, Default)]
pub struct CacheValidationResult {
    /// Models whose plain or compressed file exists on disk.
    pub valid_models: usize,
    /// Models with no backing file.
    pub invalid_models: usize,
    /// Repositories whose checkout directory exists on disk.
    pub valid_repositories: usize,
    /// Repositories with no backing directory.
    pub invalid_repositories: usize,
    /// Metadata keys of models missing their file.
    pub missing_files: Vec<String>,
    /// Metadata keys of repositories missing their directory.
    pub missing_directories: Vec<String>,
}
569
/// Per-file statistics returned by `compress_file`.
#[derive(Debug, Clone)]
pub struct FileCompressionStats {
    /// Input file size in bytes.
    pub original_size: u64,
    /// Output (gzip) file size in bytes.
    pub compressed_size: u64,
    /// compressed / original; 1.0 for an empty input (avoids 0/0).
    pub compression_ratio: f32,
}
577
/// Render a byte count human-readably: exact integer bytes below 1 KiB
/// (e.g. "512 B"), otherwise one decimal place in the largest unit that
/// keeps the value >= 1 (e.g. "1.5 KB", "2.0 GB"). Units are 1024-based
/// and capped at TB.
pub fn format_bytes(bytes: u64) -> String {
    const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB"];

    if bytes == 0 {
        return "0 B".to_string();
    }

    // Repeatedly scale down by 1024 until the value fits the unit, without
    // running past the last unit in the table.
    let mut value = bytes as f64;
    let mut unit = 0usize;
    while value >= 1024.0 && unit + 1 < UNITS.len() {
        value /= 1024.0;
        unit += 1;
    }

    match unit {
        // Sub-KB counts are shown exactly, with no decimal point.
        0 => format!("{} B", bytes),
        _ => format!("{:.1} {}", value, UNITS[unit]),
    }
}
601
602pub fn compress_file(input_path: &Path, output_path: &Path) -> Result<FileCompressionStats> {
604 use flate2::write::GzEncoder;
605 use flate2::Compression;
606 use std::io::copy;
607
608 let input_file = File::open(input_path)?;
609 let output_file = File::create(output_path)?;
610 let mut encoder = GzEncoder::new(output_file, Compression::default());
611
612 let original_size = input_file.metadata()?.len();
613
614 let mut input_reader = std::io::BufReader::new(input_file);
615 copy(&mut input_reader, &mut encoder)?;
616 encoder.finish()?;
617
618 let compressed_size = fs::metadata(output_path)?.len();
619 let compression_ratio = if original_size > 0 {
620 compressed_size as f32 / original_size as f32
621 } else {
622 1.0
623 };
624
625 Ok(FileCompressionStats {
626 original_size,
627 compressed_size,
628 compression_ratio,
629 })
630}
631
632pub fn decompress_file(input_path: &Path, output_path: &Path) -> Result<()> {
634 use flate2::read::GzDecoder;
635 use std::io::copy;
636
637 let input_file = File::open(input_path)?;
638 let mut decoder = GzDecoder::new(input_file);
639 let mut output_file = File::create(output_path)?;
640
641 copy(&mut decoder, &mut output_file)?;
642 Ok(())
643}
644
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // End-to-end smoke test: add a repo and a model, check the lookup
    // predicates flip to true, and verify the recorded total size.
    #[test]
    fn test_cache_manager() {
        let temp_dir = TempDir::new().unwrap();
        let mut cache = CacheManager::new(temp_dir.path()).unwrap();

        assert!(!cache.is_repo_cached("owner", "repo", "main"));
        cache.add_repo("owner", "repo", "main", 1000).unwrap();
        assert!(cache.is_repo_cached("owner", "repo", "main"));

        assert!(!cache.is_model_cached("owner/repo", "model", "v1.0"));
        cache
            .add_model("owner/repo", "model", "v1.0", "hash123", 2000)
            .unwrap();
        assert!(cache.is_model_cached("owner/repo", "model", "v1.0"));

        // 1000 (repo) + 2000 (model) bytes recorded in metadata.
        assert_eq!(cache.get_cache_size(), 3000);
    }

    // LRU eviction test: with a 1500-byte budget, the least-recently
    // accessed model (model2, never touched after insertion) is evicted and
    // the recently touched model1 survives.
    #[test]
    fn test_cache_cleaning() {
        let temp_dir = TempDir::new().unwrap();
        let mut cache = CacheManager::new(temp_dir.path()).unwrap();

        cache
            .add_model("repo1", "model1", "v1", "hash1", 1000)
            .unwrap();

        // Ensure the two models get distinguishable timestamps.
        std::thread::sleep(std::time::Duration::from_millis(10));

        cache
            .add_model("repo2", "model2", "v1", "hash2", 2000)
            .unwrap();

        // Touch model1 so it becomes the most recently accessed.
        cache.touch_model("repo1", "model1", "v1").unwrap();

        cache.clean_cache(Some(1500), None).unwrap();

        assert_eq!(cache.metadata.models.len(), 1);
        assert!(cache.is_model_cached("repo1", "model1", "v1"));
    }
}