// Source: socket_patch_core/api/blob_fetcher.rs

1use std::collections::HashSet;
2use std::path::{Path, PathBuf};
3
4use crate::api::client::ApiClient;
5use crate::manifest::operations::get_after_hash_blobs;
6use crate::manifest::schema::PatchManifest;
7
/// Outcome recorded for one blob, identified by its hash.
#[derive(Clone, Debug)]
pub struct BlobFetchResult {
    /// Hash that names the blob (and its file under the blobs directory).
    pub hash: String,
    /// `true` when the blob is available locally after the operation
    /// (freshly downloaded or already on disk), `false` on failure.
    pub success: bool,
    /// Human-readable failure description; `None` for successful entries.
    pub error: Option<String>,
}
15
16/// Aggregate result of a blob-fetch operation.
17#[derive(Debug, Clone)]
18pub struct FetchMissingBlobsResult {
19    pub total: usize,
20    pub downloaded: usize,
21    pub failed: usize,
22    pub skipped: usize,
23    pub results: Vec<BlobFetchResult>,
24}
25
/// Boxed progress callback that can be shared across threads.
///
/// It receives three arguments per blob: the blob's hash, the blob's
/// 1-based position in the batch, and the batch size.
pub type OnProgress = Box<dyn Fn(&str, usize, usize) + Send + Sync>;
30
31// ── Public API ────────────────────────────────────────────────────────
32
33/// Determine which `afterHash` blobs referenced in the manifest are
34/// missing from disk.
35///
36/// Only checks `afterHash` blobs because those are the patched file
37/// contents needed for applying patches. `beforeHash` blobs are
38/// downloaded on-demand during rollback.
39pub async fn get_missing_blobs(
40    manifest: &PatchManifest,
41    blobs_path: &Path,
42) -> HashSet<String> {
43    let after_hash_blobs = get_after_hash_blobs(manifest);
44    let mut missing = HashSet::new();
45
46    for hash in after_hash_blobs {
47        let blob_path = blobs_path.join(&hash);
48        if tokio::fs::metadata(&blob_path).await.is_err() {
49            missing.insert(hash);
50        }
51    }
52
53    missing
54}
55
56/// Download all missing `afterHash` blobs referenced in the manifest.
57///
58/// Creates the `blobs_path` directory if it does not exist.
59///
60/// # Arguments
61///
62/// * `manifest`    – Patch manifest whose `afterHash` blobs to check.
63/// * `blobs_path`  – Directory where blob files are stored (one file per
64///   hash).
65/// * `client`      – [`ApiClient`] used to fetch blobs from the server.
66/// * `on_progress` – Optional callback invoked before each download with
67///   `(hash, 1-based index, total)`.
68pub async fn fetch_missing_blobs(
69    manifest: &PatchManifest,
70    blobs_path: &Path,
71    client: &ApiClient,
72    on_progress: Option<&OnProgress>,
73) -> FetchMissingBlobsResult {
74    let missing = get_missing_blobs(manifest, blobs_path).await;
75
76    if missing.is_empty() {
77        return FetchMissingBlobsResult {
78            total: 0,
79            downloaded: 0,
80            failed: 0,
81            skipped: 0,
82            results: Vec::new(),
83        };
84    }
85
86    // Ensure blobs directory exists
87    if let Err(e) = tokio::fs::create_dir_all(blobs_path).await {
88        // If we cannot create the directory, every blob will fail.
89        let results: Vec<BlobFetchResult> = missing
90            .iter()
91            .map(|h| BlobFetchResult {
92                hash: h.clone(),
93                success: false,
94                error: Some(format!("Cannot create blobs directory: {}", e)),
95            })
96            .collect();
97        let failed = results.len();
98        return FetchMissingBlobsResult {
99            total: failed,
100            downloaded: 0,
101            failed,
102            skipped: 0,
103            results,
104        };
105    }
106
107    let hashes: Vec<String> = missing.into_iter().collect();
108    download_hashes(&hashes, blobs_path, client, on_progress).await
109}
110
111/// Download specific blobs identified by their hashes.
112///
113/// Useful for fetching `beforeHash` blobs during rollback, where only a
114/// subset of hashes is required.
115///
116/// Blobs that already exist on disk are skipped (counted in `skipped`).
117pub async fn fetch_blobs_by_hash(
118    hashes: &HashSet<String>,
119    blobs_path: &Path,
120    client: &ApiClient,
121    on_progress: Option<&OnProgress>,
122) -> FetchMissingBlobsResult {
123    if hashes.is_empty() {
124        return FetchMissingBlobsResult {
125            total: 0,
126            downloaded: 0,
127            failed: 0,
128            skipped: 0,
129            results: Vec::new(),
130        };
131    }
132
133    // Ensure blobs directory exists
134    if let Err(e) = tokio::fs::create_dir_all(blobs_path).await {
135        let results: Vec<BlobFetchResult> = hashes
136            .iter()
137            .map(|h| BlobFetchResult {
138                hash: h.clone(),
139                success: false,
140                error: Some(format!("Cannot create blobs directory: {}", e)),
141            })
142            .collect();
143        let failed = results.len();
144        return FetchMissingBlobsResult {
145            total: failed,
146            downloaded: 0,
147            failed,
148            skipped: 0,
149            results,
150        };
151    }
152
153    // Filter out hashes that already exist on disk
154    let mut to_download: Vec<String> = Vec::new();
155    let mut skipped: usize = 0;
156    let mut results: Vec<BlobFetchResult> = Vec::new();
157
158    for hash in hashes {
159        let blob_path = blobs_path.join(hash);
160        if tokio::fs::metadata(&blob_path).await.is_ok() {
161            skipped += 1;
162            results.push(BlobFetchResult {
163                hash: hash.clone(),
164                success: true,
165                error: None,
166            });
167        } else {
168            to_download.push(hash.clone());
169        }
170    }
171
172    if to_download.is_empty() {
173        return FetchMissingBlobsResult {
174            total: hashes.len(),
175            downloaded: 0,
176            failed: 0,
177            skipped,
178            results,
179        };
180    }
181
182    let download_result =
183        download_hashes(&to_download, blobs_path, client, on_progress).await;
184
185    FetchMissingBlobsResult {
186        total: hashes.len(),
187        downloaded: download_result.downloaded,
188        failed: download_result.failed,
189        skipped,
190        results: {
191            let mut combined = results;
192            combined.extend(download_result.results);
193            combined
194        },
195    }
196}
197
198/// Format a [`FetchMissingBlobsResult`] as a human-readable string.
199pub fn format_fetch_result(result: &FetchMissingBlobsResult) -> String {
200    if result.total == 0 {
201        return "All blobs are present locally.".to_string();
202    }
203
204    let mut lines: Vec<String> = Vec::new();
205
206    if result.downloaded > 0 {
207        lines.push(format!("Downloaded {} blob(s)", result.downloaded));
208    }
209
210    if result.failed > 0 {
211        lines.push(format!("Failed to download {} blob(s)", result.failed));
212
213        let failed_results: Vec<&BlobFetchResult> =
214            result.results.iter().filter(|r| !r.success).collect();
215
216        for r in failed_results.iter().take(5) {
217            let short_hash = if r.hash.len() >= 12 {
218                &r.hash[..12]
219            } else {
220                &r.hash
221            };
222            let err = r.error.as_deref().unwrap_or("unknown error");
223            lines.push(format!("  - {}...: {}", short_hash, err));
224        }
225
226        if failed_results.len() > 5 {
227            lines.push(format!("  ... and {} more", failed_results.len() - 5));
228        }
229    }
230
231    lines.join("\n")
232}
233
234// ── Internal helpers ──────────────────────────────────────────────────
235
236/// Download a list of blob hashes sequentially, writing each to
237/// `blobs_path/<hash>`.
238async fn download_hashes(
239    hashes: &[String],
240    blobs_path: &Path,
241    client: &ApiClient,
242    on_progress: Option<&OnProgress>,
243) -> FetchMissingBlobsResult {
244    let total = hashes.len();
245    let mut downloaded: usize = 0;
246    let mut failed: usize = 0;
247    let mut results: Vec<BlobFetchResult> = Vec::with_capacity(total);
248
249    for (i, hash) in hashes.iter().enumerate() {
250        if let Some(ref cb) = on_progress {
251            cb(hash, i + 1, total);
252        }
253
254        match client.fetch_blob(hash).await {
255            Ok(Some(data)) => {
256                let blob_path: PathBuf = blobs_path.join(hash);
257                match tokio::fs::write(&blob_path, &data).await {
258                    Ok(()) => {
259                        results.push(BlobFetchResult {
260                            hash: hash.clone(),
261                            success: true,
262                            error: None,
263                        });
264                        downloaded += 1;
265                    }
266                    Err(e) => {
267                        results.push(BlobFetchResult {
268                            hash: hash.clone(),
269                            success: false,
270                            error: Some(format!("Failed to write blob to disk: {}", e)),
271                        });
272                        failed += 1;
273                    }
274                }
275            }
276            Ok(None) => {
277                results.push(BlobFetchResult {
278                    hash: hash.clone(),
279                    success: false,
280                    error: Some("Blob not found on server".to_string()),
281                });
282                failed += 1;
283            }
284            Err(e) => {
285                results.push(BlobFetchResult {
286                    hash: hash.clone(),
287                    success: false,
288                    error: Some(e.to_string()),
289                });
290                failed += 1;
291            }
292        }
293    }
294
295    FetchMissingBlobsResult {
296        total,
297        downloaded,
298        failed,
299        skipped: 0,
300        results,
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::manifest::schema::{PatchFileInfo, PatchManifest, PatchRecord};
    use std::collections::HashMap;

    // ── Fixtures ─────────────────────────────────────────────────────

    /// Build a manifest with a single patch whose files carry the given
    /// `after_hash` values (one synthetic file per hash).
    fn make_manifest_with_hashes(after_hashes: &[&str]) -> PatchManifest {
        let files: HashMap<_, _> = after_hashes
            .iter()
            .enumerate()
            .map(|(i, ah)| {
                (
                    format!("package/file{}.js", i),
                    PatchFileInfo {
                        // "before" + 58 zeros + 6-digit index = 64 chars.
                        before_hash: format!("before{}{:06}", "0".repeat(58), i),
                        after_hash: ah.to_string(),
                    },
                )
            })
            .collect();

        let mut patches = HashMap::new();
        patches.insert(
            "pkg:npm/test@1.0.0".to_string(),
            PatchRecord {
                uuid: "test-uuid".to_string(),
                exported_at: "2024-01-01T00:00:00Z".to_string(),
                files,
                vulnerabilities: HashMap::new(),
                description: "test".to_string(),
                license: "MIT".to_string(),
                tier: "free".to_string(),
            },
        );

        PatchManifest { patches }
    }

    /// Create a temporary directory containing an empty `blobs` folder.
    ///
    /// The returned `TempDir` guard must be kept alive for as long as the
    /// path is used; dropping it deletes the directory.
    async fn temp_blobs_dir() -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let blobs_path = dir.path().join("blobs");
        tokio::fs::create_dir_all(&blobs_path).await.unwrap();
        (dir, blobs_path)
    }

    /// A successful per-blob entry.
    fn ok_blob(hash: String) -> BlobFetchResult {
        BlobFetchResult {
            hash,
            success: true,
            error: None,
        }
    }

    /// A failed per-blob entry with an optional error message.
    fn failed_blob(hash: String, error: Option<&str>) -> BlobFetchResult {
        BlobFetchResult {
            hash,
            success: false,
            error: error.map(str::to_string),
        }
    }

    /// An aggregate result with `skipped` fixed at zero.
    fn summary(
        total: usize,
        downloaded: usize,
        failed: usize,
        results: Vec<BlobFetchResult>,
    ) -> FetchMissingBlobsResult {
        FetchMissingBlobsResult {
            total,
            downloaded,
            failed,
            skipped: 0,
            results,
        }
    }

    // ── get_missing_blobs ────────────────────────────────────────────

    #[tokio::test]
    async fn test_get_missing_blobs_all_missing() {
        let (_dir, blobs_path) = temp_blobs_dir().await;

        let h1 = "a".repeat(64);
        let h2 = "b".repeat(64);
        let manifest = make_manifest_with_hashes(&[&h1, &h2]);

        let missing = get_missing_blobs(&manifest, &blobs_path).await;
        assert_eq!(missing.len(), 2);
        assert!(missing.contains(&h1));
        assert!(missing.contains(&h2));
    }

    #[tokio::test]
    async fn test_get_missing_blobs_some_present() {
        let (_dir, blobs_path) = temp_blobs_dir().await;

        let h1 = "a".repeat(64);
        let h2 = "b".repeat(64);

        // Write h1 to disk so it is NOT missing
        tokio::fs::write(blobs_path.join(&h1), b"data").await.unwrap();

        let manifest = make_manifest_with_hashes(&[&h1, &h2]);
        let missing = get_missing_blobs(&manifest, &blobs_path).await;
        assert_eq!(missing.len(), 1);
        assert!(missing.contains(&h2));
        assert!(!missing.contains(&h1));
    }

    #[tokio::test]
    async fn test_get_missing_blobs_empty_manifest() {
        let (_dir, blobs_path) = temp_blobs_dir().await;

        let manifest = PatchManifest::new();
        let missing = get_missing_blobs(&manifest, &blobs_path).await;
        assert!(missing.is_empty());
    }

    // ── format_fetch_result ──────────────────────────────────────────

    #[test]
    fn test_format_fetch_result_all_present() {
        let result = summary(0, 0, 0, Vec::new());
        assert_eq!(format_fetch_result(&result), "All blobs are present locally.");
    }

    #[test]
    fn test_format_fetch_result_some_downloaded() {
        let result = summary(
            3,
            2,
            1,
            vec![
                ok_blob("a".repeat(64)),
                ok_blob("b".repeat(64)),
                failed_blob("c".repeat(64), Some("Blob not found on server")),
            ],
        );
        let output = format_fetch_result(&result);
        assert!(output.contains("Downloaded 2 blob(s)"));
        assert!(output.contains("Failed to download 1 blob(s)"));
        assert!(output.contains("cccccccccccc..."));
        assert!(output.contains("Blob not found on server"));
    }

    #[test]
    fn test_format_fetch_result_truncates_at_5() {
        let results: Vec<BlobFetchResult> = (0..8)
            .map(|i| failed_blob(format!("{:0>64}", i), Some(&format!("error {}", i))))
            .collect();

        let result = summary(8, 0, 8, results);
        let output = format_fetch_result(&result);
        assert!(output.contains("... and 3 more"));
    }

    // ── Group 8: format edge cases ───────────────────────────────────

    #[test]
    fn test_format_only_downloaded() {
        let result = summary(
            3,
            3,
            0,
            vec![
                ok_blob("a".repeat(64)),
                ok_blob("b".repeat(64)),
                ok_blob("c".repeat(64)),
            ],
        );
        let output = format_fetch_result(&result);
        assert!(output.contains("Downloaded 3 blob(s)"));
        assert!(!output.contains("Failed"));
    }

    #[test]
    fn test_format_short_hash() {
        let result = summary(1, 0, 1, vec![failed_blob("abc".into(), Some("not found"))]);
        let output = format_fetch_result(&result);
        // Hash is < 12 chars, should show full hash
        assert!(output.contains("abc..."));
    }

    #[test]
    fn test_format_error_none() {
        let result = summary(1, 0, 1, vec![failed_blob("d".repeat(64), None)]);
        let output = format_fetch_result(&result);
        assert!(output.contains("unknown error"));
    }

    #[test]
    fn test_format_only_failed() {
        let result = summary(
            2,
            0,
            2,
            vec![
                failed_blob("a".repeat(64), Some("timeout")),
                failed_blob("b".repeat(64), Some("timeout")),
            ],
        );
        let output = format_fetch_result(&result);
        assert!(!output.contains("Downloaded"));
        assert!(output.contains("Failed to download 2 blob(s)"));
    }
}