1use std::collections::HashSet;
2use std::path::{Path, PathBuf};
3
4use crate::api::client::ApiClient;
5use crate::manifest::operations::get_after_hash_blobs;
6use crate::manifest::schema::PatchManifest;
7use crate::patch::apply::PatchSources;
8
/// Which kind of patch source to download from the API.
///
/// Convert to a short tag with [`DownloadMode::as_tag`] and parse user input
/// with [`DownloadMode::parse`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum DownloadMode {
    /// Per-patch diff archives, stored as `<uuid>.tar.gz`.
    Diff,
    /// Per-patch package archives, stored as `<uuid>.tar.gz`.
    Package,
    /// Individual file blobs, stored under their content hash.
    File,
}
22
23impl DownloadMode {
24 pub fn as_tag(&self) -> &'static str {
27 match self {
28 DownloadMode::Diff => "diff",
29 DownloadMode::Package => "package",
30 DownloadMode::File => "file",
31 }
32 }
33
34 pub fn parse(s: &str) -> Result<Self, String> {
36 match s.to_ascii_lowercase().as_str() {
37 "diff" => Ok(DownloadMode::Diff),
38 "package" => Ok(DownloadMode::Package),
39 "file" | "blob" => Ok(DownloadMode::File),
40 other => Err(format!(
41 "unknown download mode '{}'. Expected diff, package, or file.",
42 other
43 )),
44 }
45 }
46}
47
/// Outcome of fetching one item: a blob identified by its content hash, or an
/// archive identified by its patch UUID.
#[derive(Debug)]
#[derive(Clone)]
pub struct BlobFetchResult {
    /// Identifier of the item: the blob hash, or the patch UUID for archives.
    pub hash: String,
    /// `true` when the item is available locally (downloaded or already present).
    pub success: bool,
    /// Failure reason, if any. This module leaves it `None` for successful items.
    pub error: Option<String>,
}
55
56#[derive(Debug, Clone)]
58pub struct FetchMissingBlobsResult {
59 pub total: usize,
60 pub downloaded: usize,
61 pub failed: usize,
62 pub skipped: usize,
63 pub results: Vec<BlobFetchResult>,
64}
65
/// Progress callback, invoked before each item is fetched with
/// `(identifier, 1-based position, total count)`.
pub type OnProgress = Box<dyn Fn(&str, usize, usize) + Send + Sync + 'static>;
70
71pub async fn get_missing_blobs(
80 manifest: &PatchManifest,
81 blobs_path: &Path,
82) -> HashSet<String> {
83 let after_hash_blobs = get_after_hash_blobs(manifest);
84 let mut missing = HashSet::new();
85
86 for hash in after_hash_blobs {
87 let blob_path = blobs_path.join(&hash);
88 if tokio::fs::metadata(&blob_path).await.is_err() {
89 missing.insert(hash);
90 }
91 }
92
93 missing
94}
95
96pub async fn fetch_missing_blobs(
109 manifest: &PatchManifest,
110 blobs_path: &Path,
111 client: &ApiClient,
112 on_progress: Option<&OnProgress>,
113) -> FetchMissingBlobsResult {
114 let missing = get_missing_blobs(manifest, blobs_path).await;
115
116 if missing.is_empty() {
117 return FetchMissingBlobsResult {
118 total: 0,
119 downloaded: 0,
120 failed: 0,
121 skipped: 0,
122 results: Vec::new(),
123 };
124 }
125
126 if let Err(e) = tokio::fs::create_dir_all(blobs_path).await {
128 let results: Vec<BlobFetchResult> = missing
130 .iter()
131 .map(|h| BlobFetchResult {
132 hash: h.clone(),
133 success: false,
134 error: Some(format!("Cannot create blobs directory: {}", e)),
135 })
136 .collect();
137 let failed = results.len();
138 return FetchMissingBlobsResult {
139 total: failed,
140 downloaded: 0,
141 failed,
142 skipped: 0,
143 results,
144 };
145 }
146
147 let hashes: Vec<String> = missing.into_iter().collect();
148 download_hashes(&hashes, blobs_path, client, on_progress).await
149}
150
151pub async fn fetch_blobs_by_hash(
158 hashes: &HashSet<String>,
159 blobs_path: &Path,
160 client: &ApiClient,
161 on_progress: Option<&OnProgress>,
162) -> FetchMissingBlobsResult {
163 if hashes.is_empty() {
164 return FetchMissingBlobsResult {
165 total: 0,
166 downloaded: 0,
167 failed: 0,
168 skipped: 0,
169 results: Vec::new(),
170 };
171 }
172
173 if let Err(e) = tokio::fs::create_dir_all(blobs_path).await {
175 let results: Vec<BlobFetchResult> = hashes
176 .iter()
177 .map(|h| BlobFetchResult {
178 hash: h.clone(),
179 success: false,
180 error: Some(format!("Cannot create blobs directory: {}", e)),
181 })
182 .collect();
183 let failed = results.len();
184 return FetchMissingBlobsResult {
185 total: failed,
186 downloaded: 0,
187 failed,
188 skipped: 0,
189 results,
190 };
191 }
192
193 let mut to_download: Vec<String> = Vec::new();
195 let mut skipped: usize = 0;
196 let mut results: Vec<BlobFetchResult> = Vec::new();
197
198 for hash in hashes {
199 let blob_path = blobs_path.join(hash);
200 if tokio::fs::metadata(&blob_path).await.is_ok() {
201 skipped += 1;
202 results.push(BlobFetchResult {
203 hash: hash.clone(),
204 success: true,
205 error: None,
206 });
207 } else {
208 to_download.push(hash.clone());
209 }
210 }
211
212 if to_download.is_empty() {
213 return FetchMissingBlobsResult {
214 total: hashes.len(),
215 downloaded: 0,
216 failed: 0,
217 skipped,
218 results,
219 };
220 }
221
222 let download_result =
223 download_hashes(&to_download, blobs_path, client, on_progress).await;
224
225 FetchMissingBlobsResult {
226 total: hashes.len(),
227 downloaded: download_result.downloaded,
228 failed: download_result.failed,
229 skipped,
230 results: {
231 let mut combined = results;
232 combined.extend(download_result.results);
233 combined
234 },
235 }
236}
237
238pub async fn get_missing_archives(
242 manifest: &PatchManifest,
243 archives_dir: &Path,
244) -> HashSet<String> {
245 let mut missing = HashSet::new();
246 for record in manifest.patches.values() {
247 let archive_path = archives_dir.join(format!("{}.tar.gz", record.uuid));
248 if tokio::fs::metadata(&archive_path).await.is_err() {
249 missing.insert(record.uuid.clone());
250 }
251 }
252 missing
253}
254
255pub async fn fetch_missing_sources(
269 manifest: &PatchManifest,
270 sources: &PatchSources<'_>,
271 mode: DownloadMode,
272 client: &ApiClient,
273 on_progress: Option<&OnProgress>,
274) -> FetchMissingBlobsResult {
275 match mode {
276 DownloadMode::File => {
277 fetch_missing_blobs(manifest, sources.blobs_path, client, on_progress).await
278 }
279 DownloadMode::Diff => match sources.diffs_path {
280 Some(dir) => {
281 fetch_missing_archives_inner(manifest, dir, ArchiveKind::Diff, client, on_progress)
282 .await
283 }
284 None => empty_result(),
285 },
286 DownloadMode::Package => match sources.packages_path {
287 Some(dir) => fetch_missing_archives_inner(
288 manifest,
289 dir,
290 ArchiveKind::Package,
291 client,
292 on_progress,
293 )
294 .await,
295 None => empty_result(),
296 },
297 }
298}
299
/// Which archive endpoint a download targets. Selects between
/// `ApiClient::fetch_diff` and `ApiClient::fetch_package`.
#[derive(Clone, Copy, Debug)]
enum ArchiveKind {
    /// Diff archive for a patch.
    Diff,
    /// Package archive for a patch.
    Package,
}
305
306fn empty_result() -> FetchMissingBlobsResult {
307 FetchMissingBlobsResult {
308 total: 0,
309 downloaded: 0,
310 failed: 0,
311 skipped: 0,
312 results: Vec::new(),
313 }
314}
315
316async fn fetch_missing_archives_inner(
317 manifest: &PatchManifest,
318 archives_dir: &Path,
319 kind: ArchiveKind,
320 client: &ApiClient,
321 on_progress: Option<&OnProgress>,
322) -> FetchMissingBlobsResult {
323 let missing = get_missing_archives(manifest, archives_dir).await;
324 if missing.is_empty() {
325 return empty_result();
326 }
327
328 if let Err(e) = tokio::fs::create_dir_all(archives_dir).await {
329 let results: Vec<BlobFetchResult> = missing
330 .iter()
331 .map(|u| BlobFetchResult {
332 hash: u.clone(),
333 success: false,
334 error: Some(format!("Cannot create archives directory: {}", e)),
335 })
336 .collect();
337 let failed = results.len();
338 return FetchMissingBlobsResult {
339 total: failed,
340 downloaded: 0,
341 failed,
342 skipped: 0,
343 results,
344 };
345 }
346
347 let uuids: Vec<String> = missing.into_iter().collect();
348 let total = uuids.len();
349 let mut downloaded = 0usize;
350 let mut failed = 0usize;
351 let mut results = Vec::with_capacity(total);
352
353 for (i, uuid) in uuids.iter().enumerate() {
354 if let Some(ref cb) = on_progress {
355 cb(uuid, i + 1, total);
356 }
357
358 let fetch_result = match kind {
359 ArchiveKind::Diff => client.fetch_diff(uuid).await,
360 ArchiveKind::Package => client.fetch_package(uuid).await,
361 };
362
363 match fetch_result {
364 Ok(Some(data)) => {
365 let archive_path: PathBuf = archives_dir.join(format!("{}.tar.gz", uuid));
366 match tokio::fs::write(&archive_path, &data).await {
367 Ok(()) => {
368 results.push(BlobFetchResult {
369 hash: uuid.clone(),
370 success: true,
371 error: None,
372 });
373 downloaded += 1;
374 }
375 Err(e) => {
376 results.push(BlobFetchResult {
377 hash: uuid.clone(),
378 success: false,
379 error: Some(format!("Failed to write archive to disk: {}", e)),
380 });
381 failed += 1;
382 }
383 }
384 }
385 Ok(None) => {
386 results.push(BlobFetchResult {
387 hash: uuid.clone(),
388 success: false,
389 error: Some(format!(
390 "{} archive not found on server",
391 match kind {
392 ArchiveKind::Diff => "Diff",
393 ArchiveKind::Package => "Package",
394 }
395 )),
396 });
397 failed += 1;
398 }
399 Err(e) => {
400 results.push(BlobFetchResult {
401 hash: uuid.clone(),
402 success: false,
403 error: Some(e.to_string()),
404 });
405 failed += 1;
406 }
407 }
408 }
409
410 FetchMissingBlobsResult {
411 total,
412 downloaded,
413 failed,
414 skipped: 0,
415 results,
416 }
417}
418
419pub fn format_fetch_result(result: &FetchMissingBlobsResult) -> String {
421 if result.total == 0 {
422 return "All blobs are present locally.".to_string();
423 }
424
425 let mut lines: Vec<String> = Vec::new();
426
427 if result.downloaded > 0 {
428 lines.push(format!("Downloaded {} blob(s)", result.downloaded));
429 }
430
431 if result.failed > 0 {
432 lines.push(format!("Failed to download {} blob(s)", result.failed));
433
434 let failed_results: Vec<&BlobFetchResult> =
435 result.results.iter().filter(|r| !r.success).collect();
436
437 for r in failed_results.iter().take(5) {
438 let short_hash = if r.hash.len() >= 12 {
439 &r.hash[..12]
440 } else {
441 &r.hash
442 };
443 let err = r.error.as_deref().unwrap_or("unknown error");
444 lines.push(format!(" - {}...: {}", short_hash, err));
445 }
446
447 if failed_results.len() > 5 {
448 lines.push(format!(" ... and {} more", failed_results.len() - 5));
449 }
450 }
451
452 lines.join("\n")
453}
454
455async fn download_hashes(
460 hashes: &[String],
461 blobs_path: &Path,
462 client: &ApiClient,
463 on_progress: Option<&OnProgress>,
464) -> FetchMissingBlobsResult {
465 let total = hashes.len();
466 let mut downloaded: usize = 0;
467 let mut failed: usize = 0;
468 let mut results: Vec<BlobFetchResult> = Vec::with_capacity(total);
469
470 for (i, hash) in hashes.iter().enumerate() {
471 if let Some(ref cb) = on_progress {
472 cb(hash, i + 1, total);
473 }
474
475 match client.fetch_blob(hash).await {
476 Ok(Some(data)) => {
477 let actual_hash = crate::hash::git_sha256::compute_git_sha256_from_bytes(&data);
479 if actual_hash != *hash {
480 results.push(BlobFetchResult {
481 hash: hash.clone(),
482 success: false,
483 error: Some(format!(
484 "Content hash mismatch: expected {}, got {}",
485 hash, actual_hash
486 )),
487 });
488 failed += 1;
489 continue;
490 }
491
492 let blob_path: PathBuf = blobs_path.join(hash);
493 match tokio::fs::write(&blob_path, &data).await {
494 Ok(()) => {
495 results.push(BlobFetchResult {
496 hash: hash.clone(),
497 success: true,
498 error: None,
499 });
500 downloaded += 1;
501 }
502 Err(e) => {
503 results.push(BlobFetchResult {
504 hash: hash.clone(),
505 success: false,
506 error: Some(format!("Failed to write blob to disk: {}", e)),
507 });
508 failed += 1;
509 }
510 }
511 }
512 Ok(None) => {
513 results.push(BlobFetchResult {
514 hash: hash.clone(),
515 success: false,
516 error: Some("Blob not found on server".to_string()),
517 });
518 failed += 1;
519 }
520 Err(e) => {
521 results.push(BlobFetchResult {
522 hash: hash.clone(),
523 success: false,
524 error: Some(e.to_string()),
525 });
526 failed += 1;
527 }
528 }
529 }
530
531 FetchMissingBlobsResult {
532 total,
533 downloaded,
534 failed,
535 skipped: 0,
536 results,
537 }
538}
539
// Unit tests for blob/archive presence detection, result formatting and
// download-mode parsing. Only `test_fetch_missing_sources_unsupported_mode_returns_empty`
// touches an `ApiClient`; the other tests never construct one.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::manifest::schema::{PatchFileInfo, PatchManifest, PatchRecord};
    use std::collections::HashMap;

    /// Builds a single-patch manifest (`pkg:npm/test@1.0.0`, uuid `test-uuid`)
    /// whose files `package/file<i>.js` have the given after-hashes.
    fn make_manifest_with_hashes(after_hashes: &[&str]) -> PatchManifest {
        let mut files = HashMap::new();
        for (i, ah) in after_hashes.iter().enumerate() {
            files.insert(
                format!("package/file{}.js", i),
                PatchFileInfo {
                    // Synthetic before-hash, unique per file index; it is not a
                    // real content hash.
                    before_hash: format!(
                        "before{}{}",
                        "0".repeat(58),
                        format!("{:06}", i)
                    ),
                    after_hash: ah.to_string(),
                },
            );
        }

        let mut patches = HashMap::new();
        patches.insert(
            "pkg:npm/test@1.0.0".to_string(),
            PatchRecord {
                uuid: "test-uuid".to_string(),
                exported_at: "2024-01-01T00:00:00Z".to_string(),
                files,
                vulnerabilities: HashMap::new(),
                description: "test".to_string(),
                license: "MIT".to_string(),
                tier: "free".to_string(),
            },
        );

        PatchManifest { patches }
    }

    // An empty blobs directory means every after-hash is reported missing.
    #[tokio::test]
    async fn test_get_missing_blobs_all_missing() {
        let dir = tempfile::tempdir().unwrap();
        let blobs_path = dir.path().join("blobs");
        tokio::fs::create_dir_all(&blobs_path).await.unwrap();

        let h1 = "a".repeat(64);
        let h2 = "b".repeat(64);
        let manifest = make_manifest_with_hashes(&[&h1, &h2]);

        let missing = get_missing_blobs(&manifest, &blobs_path).await;
        assert_eq!(missing.len(), 2);
        assert!(missing.contains(&h1));
        assert!(missing.contains(&h2));
    }

    // Only hashes without a file are missing; file contents are irrelevant.
    #[tokio::test]
    async fn test_get_missing_blobs_some_present() {
        let dir = tempfile::tempdir().unwrap();
        let blobs_path = dir.path().join("blobs");
        tokio::fs::create_dir_all(&blobs_path).await.unwrap();

        let h1 = "a".repeat(64);
        let h2 = "b".repeat(64);

        tokio::fs::write(blobs_path.join(&h1), b"data").await.unwrap();
        // "data" does not hash to h1, which shows that presence is existence-only.

        let manifest = make_manifest_with_hashes(&[&h1, &h2]);
        let missing = get_missing_blobs(&manifest, &blobs_path).await;
        assert_eq!(missing.len(), 1);
        assert!(missing.contains(&h2));
        assert!(!missing.contains(&h1));
    }

    // A manifest with no patches references no blobs.
    #[tokio::test]
    async fn test_get_missing_blobs_empty_manifest() {
        let dir = tempfile::tempdir().unwrap();
        let blobs_path = dir.path().join("blobs");
        tokio::fs::create_dir_all(&blobs_path).await.unwrap();

        let manifest = PatchManifest::new();
        let missing = get_missing_blobs(&manifest, &blobs_path).await;
        assert!(missing.is_empty());
    }

    // total == 0 renders the "all present" message.
    #[test]
    fn test_format_fetch_result_all_present() {
        let result = FetchMissingBlobsResult {
            total: 0,
            downloaded: 0,
            failed: 0,
            skipped: 0,
            results: Vec::new(),
        };
        assert_eq!(format_fetch_result(&result), "All blobs are present locally.");
    }

    // Mixed outcome: both summary lines plus the 12-char-truncated failed hash.
    #[test]
    fn test_format_fetch_result_some_downloaded() {
        let result = FetchMissingBlobsResult {
            total: 3,
            downloaded: 2,
            failed: 1,
            skipped: 0,
            results: vec![
                BlobFetchResult {
                    hash: "a".repeat(64),
                    success: true,
                    error: None,
                },
                BlobFetchResult {
                    hash: "b".repeat(64),
                    success: true,
                    error: None,
                },
                BlobFetchResult {
                    hash: "c".repeat(64),
                    success: false,
                    error: Some("Blob not found on server".to_string()),
                },
            ],
        };
        let output = format_fetch_result(&result);
        assert!(output.contains("Downloaded 2 blob(s)"));
        assert!(output.contains("Failed to download 1 blob(s)"));
        assert!(output.contains("cccccccccccc..."));
        assert!(output.contains("Blob not found on server"));
    }

    // Eight failures: five are listed, the remaining three are summarised.
    #[test]
    fn test_format_fetch_result_truncates_at_5() {
        let results: Vec<BlobFetchResult> = (0..8)
            .map(|i| BlobFetchResult {
                hash: format!("{:0>64}", i),
                success: false,
                error: Some(format!("error {}", i)),
            })
            .collect();

        let result = FetchMissingBlobsResult {
            total: 8,
            downloaded: 0,
            failed: 8,
            skipped: 0,
            results,
        };
        let output = format_fetch_result(&result);
        assert!(output.contains("... and 3 more"));
    }

    // With no failures, no "Failed" line is emitted.
    #[test]
    fn test_format_only_downloaded() {
        let result = FetchMissingBlobsResult {
            total: 3,
            downloaded: 3,
            failed: 0,
            skipped: 0,
            results: vec![
                BlobFetchResult { hash: "a".repeat(64), success: true, error: None },
                BlobFetchResult { hash: "b".repeat(64), success: true, error: None },
                BlobFetchResult { hash: "c".repeat(64), success: true, error: None },
            ],
        };
        let output = format_fetch_result(&result);
        assert!(output.contains("Downloaded 3 blob(s)"));
        assert!(!output.contains("Failed"));
    }

    // Identifiers shorter than 12 characters are shown in full.
    #[test]
    fn test_format_short_hash() {
        let result = FetchMissingBlobsResult {
            total: 1,
            downloaded: 0,
            failed: 1,
            skipped: 0,
            results: vec![BlobFetchResult {
                hash: "abc".into(),
                success: false,
                error: Some("not found".into()),
            }],
        };
        let output = format_fetch_result(&result);
        assert!(output.contains("abc..."));
    }

    // A failure without an error message falls back to "unknown error".
    #[test]
    fn test_format_error_none() {
        let result = FetchMissingBlobsResult {
            total: 1,
            downloaded: 0,
            failed: 1,
            skipped: 0,
            results: vec![BlobFetchResult {
                hash: "d".repeat(64),
                success: false,
                error: None,
            }],
        };
        let output = format_fetch_result(&result);
        assert!(output.contains("unknown error"));
    }

    // Parsing ignores case, accepts "blob" as an alias for File, and rejects
    // unknown names.
    #[test]
    fn test_download_mode_parse() {
        assert_eq!(DownloadMode::parse("diff").unwrap(), DownloadMode::Diff);
        assert_eq!(DownloadMode::parse("DIFF").unwrap(), DownloadMode::Diff);
        assert_eq!(
            DownloadMode::parse("package").unwrap(),
            DownloadMode::Package
        );
        assert_eq!(DownloadMode::parse("file").unwrap(), DownloadMode::File);
        assert_eq!(DownloadMode::parse("blob").unwrap(), DownloadMode::File);
        assert!(DownloadMode::parse("nope").is_err());
    }

    #[test]
    fn test_download_mode_tag() {
        assert_eq!(DownloadMode::Diff.as_tag(), "diff");
        assert_eq!(DownloadMode::Package.as_tag(), "package");
        assert_eq!(DownloadMode::File.as_tag(), "file");
    }

    /// Builds a manifest with one file-less patch per UUID, keyed
    /// `pkg:npm/test-<i>@1.0.0`.
    fn make_manifest_with_uuids(uuids: &[&str]) -> PatchManifest {
        let mut patches = HashMap::new();
        for (i, uuid) in uuids.iter().enumerate() {
            let key = format!("pkg:npm/test-{}@1.0.0", i);
            patches.insert(
                key,
                PatchRecord {
                    uuid: (*uuid).to_string(),
                    exported_at: "2024-01-01T00:00:00Z".to_string(),
                    files: HashMap::new(),
                    vulnerabilities: HashMap::new(),
                    description: "test".to_string(),
                    license: "MIT".to_string(),
                    tier: "free".to_string(),
                },
            );
        }
        PatchManifest { patches }
    }

    // An empty archives directory means every patch UUID is reported missing.
    #[tokio::test]
    async fn test_get_missing_archives_all_missing() {
        let dir = tempfile::tempdir().unwrap();
        let archives = dir.path().join("packages");
        tokio::fs::create_dir_all(&archives).await.unwrap();

        let u1 = "11111111-1111-4111-8111-111111111111";
        let u2 = "22222222-2222-4222-8222-222222222222";
        let manifest = make_manifest_with_uuids(&[u1, u2]);

        let missing = get_missing_archives(&manifest, &archives).await;
        assert_eq!(missing.len(), 2);
        assert!(missing.contains(u1));
        assert!(missing.contains(u2));
    }

    // An existing `<uuid>.tar.gz` marks that patch as present.
    #[tokio::test]
    async fn test_get_missing_archives_some_present() {
        let dir = tempfile::tempdir().unwrap();
        let archives = dir.path().join("packages");
        tokio::fs::create_dir_all(&archives).await.unwrap();

        let u1 = "11111111-1111-4111-8111-111111111111";
        let u2 = "22222222-2222-4222-8222-222222222222";

        tokio::fs::write(archives.join(format!("{u1}.tar.gz")), b"data")
            .await
            .unwrap();

        let manifest = make_manifest_with_uuids(&[u1, u2]);
        let missing = get_missing_archives(&manifest, &archives).await;
        assert_eq!(missing.len(), 1);
        assert!(missing.contains(u2));
        assert!(!missing.contains(u1));
    }

    // Diff/Package modes return an all-zero result when their archive
    // directory is not configured.
    #[tokio::test]
    async fn test_fetch_missing_sources_unsupported_mode_returns_empty() {
        let dir = tempfile::tempdir().unwrap();
        // NOTE(review): assumes `PatchSources::blobs_only` leaves `diffs_path` and
        // `packages_path` as `None`, so no network request is made — confirm.
        let blobs = dir.path().join("blobs");
        tokio::fs::create_dir_all(&blobs).await.unwrap();
        let sources = PatchSources::blobs_only(&blobs);

        let manifest = make_manifest_with_uuids(&["11111111-1111-4111-8111-111111111111"]);
        let (client, _) = crate::api::client::get_api_client_from_env(None).await;

        let res = fetch_missing_sources(&manifest, &sources, DownloadMode::Diff, &client, None)
            .await;
        assert_eq!(res.total, 0);
        assert_eq!(res.downloaded, 0);
        assert_eq!(res.failed, 0);

        let res = fetch_missing_sources(&manifest, &sources, DownloadMode::Package, &client, None)
            .await;
        assert_eq!(res.total, 0);
    }

    // With only failures, no "Downloaded" line is emitted.
    #[test]
    fn test_format_only_failed() {
        let result = FetchMissingBlobsResult {
            total: 2,
            downloaded: 0,
            failed: 2,
            skipped: 0,
            results: vec![
                BlobFetchResult {
                    hash: "a".repeat(64),
                    success: false,
                    error: Some("timeout".into()),
                },
                BlobFetchResult {
                    hash: "b".repeat(64),
                    success: false,
                    error: Some("timeout".into()),
                },
            ],
        };
        let output = format_fetch_result(&result);
        assert!(!output.contains("Downloaded"));
        assert!(output.contains("Failed to download 2 blob(s)"));
    }
}