Skip to main content

socket_patch_core/api/
blob_fetcher.rs

1use std::collections::HashSet;
2use std::path::{Path, PathBuf};
3
4use crate::api::client::ApiClient;
5use crate::manifest::operations::get_after_hash_blobs;
6use crate::manifest::schema::PatchManifest;
7
/// Result of fetching a single blob.
///
/// `success` is `true` both for blobs that were downloaded and for blobs
/// that were already present on disk and therefore skipped; in either case
/// `error` is `None`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BlobFetchResult {
    /// Blob hash; also the blob's file name inside the blobs directory.
    pub hash: String,
    /// Whether the blob is available locally after the operation.
    pub success: bool,
    /// Human-readable failure reason when `success` is `false`.
    pub error: Option<String>,
}

/// Aggregate result of a blob-fetch operation.
///
/// The [`Default`] value is the "nothing to do" result: all counters zero
/// and no per-blob results.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FetchMissingBlobsResult {
    /// Number of hashes considered by the operation.
    pub total: usize,
    /// Number of blobs fetched from the server and written to disk.
    pub downloaded: usize,
    /// Number of blobs that could not be made available locally.
    pub failed: usize,
    /// Number of blobs that already existed on disk.
    pub skipped: usize,
    /// Per-blob outcomes.
    pub results: Vec<BlobFetchResult>,
}
25
/// Callback used to report download progress.
///
/// Invoked as `cb(hash, position, total)`, where `position` is the 1-based
/// index of the blob about to be fetched and `total` is the number of blobs
/// in the batch being downloaded. The `Sync + Send` bounds allow the boxed
/// callback to be shared with and moved across threads.
pub type OnProgress = Box<dyn Fn(&str, usize, usize) + Sync + Send>;
30
31// ── Public API ────────────────────────────────────────────────────────
32
33/// Determine which `afterHash` blobs referenced in the manifest are
34/// missing from disk.
35///
36/// Only checks `afterHash` blobs because those are the patched file
37/// contents needed for applying patches. `beforeHash` blobs are
38/// downloaded on-demand during rollback.
39pub async fn get_missing_blobs(
40    manifest: &PatchManifest,
41    blobs_path: &Path,
42) -> HashSet<String> {
43    let after_hash_blobs = get_after_hash_blobs(manifest);
44    let mut missing = HashSet::new();
45
46    for hash in after_hash_blobs {
47        let blob_path = blobs_path.join(&hash);
48        if tokio::fs::metadata(&blob_path).await.is_err() {
49            missing.insert(hash);
50        }
51    }
52
53    missing
54}
55
56/// Download all missing `afterHash` blobs referenced in the manifest.
57///
58/// Creates the `blobs_path` directory if it does not exist.
59///
60/// # Arguments
61///
62/// * `manifest`    – Patch manifest whose `afterHash` blobs to check.
63/// * `blobs_path`  – Directory where blob files are stored (one file per
64///   hash).
65/// * `client`      – [`ApiClient`] used to fetch blobs from the server.
66/// * `on_progress` – Optional callback invoked before each download with
67///   `(hash, 1-based index, total)`.
68pub async fn fetch_missing_blobs(
69    manifest: &PatchManifest,
70    blobs_path: &Path,
71    client: &ApiClient,
72    on_progress: Option<&OnProgress>,
73) -> FetchMissingBlobsResult {
74    let missing = get_missing_blobs(manifest, blobs_path).await;
75
76    if missing.is_empty() {
77        return FetchMissingBlobsResult {
78            total: 0,
79            downloaded: 0,
80            failed: 0,
81            skipped: 0,
82            results: Vec::new(),
83        };
84    }
85
86    // Ensure blobs directory exists
87    if let Err(e) = tokio::fs::create_dir_all(blobs_path).await {
88        // If we cannot create the directory, every blob will fail.
89        let results: Vec<BlobFetchResult> = missing
90            .iter()
91            .map(|h| BlobFetchResult {
92                hash: h.clone(),
93                success: false,
94                error: Some(format!("Cannot create blobs directory: {}", e)),
95            })
96            .collect();
97        let failed = results.len();
98        return FetchMissingBlobsResult {
99            total: failed,
100            downloaded: 0,
101            failed,
102            skipped: 0,
103            results,
104        };
105    }
106
107    let hashes: Vec<String> = missing.into_iter().collect();
108    download_hashes(&hashes, blobs_path, client, on_progress).await
109}
110
111/// Download specific blobs identified by their hashes.
112///
113/// Useful for fetching `beforeHash` blobs during rollback, where only a
114/// subset of hashes is required.
115///
116/// Blobs that already exist on disk are skipped (counted in `skipped`).
117pub async fn fetch_blobs_by_hash(
118    hashes: &HashSet<String>,
119    blobs_path: &Path,
120    client: &ApiClient,
121    on_progress: Option<&OnProgress>,
122) -> FetchMissingBlobsResult {
123    if hashes.is_empty() {
124        return FetchMissingBlobsResult {
125            total: 0,
126            downloaded: 0,
127            failed: 0,
128            skipped: 0,
129            results: Vec::new(),
130        };
131    }
132
133    // Ensure blobs directory exists
134    if let Err(e) = tokio::fs::create_dir_all(blobs_path).await {
135        let results: Vec<BlobFetchResult> = hashes
136            .iter()
137            .map(|h| BlobFetchResult {
138                hash: h.clone(),
139                success: false,
140                error: Some(format!("Cannot create blobs directory: {}", e)),
141            })
142            .collect();
143        let failed = results.len();
144        return FetchMissingBlobsResult {
145            total: failed,
146            downloaded: 0,
147            failed,
148            skipped: 0,
149            results,
150        };
151    }
152
153    // Filter out hashes that already exist on disk
154    let mut to_download: Vec<String> = Vec::new();
155    let mut skipped: usize = 0;
156    let mut results: Vec<BlobFetchResult> = Vec::new();
157
158    for hash in hashes {
159        let blob_path = blobs_path.join(hash);
160        if tokio::fs::metadata(&blob_path).await.is_ok() {
161            skipped += 1;
162            results.push(BlobFetchResult {
163                hash: hash.clone(),
164                success: true,
165                error: None,
166            });
167        } else {
168            to_download.push(hash.clone());
169        }
170    }
171
172    if to_download.is_empty() {
173        return FetchMissingBlobsResult {
174            total: hashes.len(),
175            downloaded: 0,
176            failed: 0,
177            skipped,
178            results,
179        };
180    }
181
182    let download_result =
183        download_hashes(&to_download, blobs_path, client, on_progress).await;
184
185    FetchMissingBlobsResult {
186        total: hashes.len(),
187        downloaded: download_result.downloaded,
188        failed: download_result.failed,
189        skipped,
190        results: {
191            let mut combined = results;
192            combined.extend(download_result.results);
193            combined
194        },
195    }
196}
197
198/// Format a [`FetchMissingBlobsResult`] as a human-readable string.
199pub fn format_fetch_result(result: &FetchMissingBlobsResult) -> String {
200    if result.total == 0 {
201        return "All blobs are present locally.".to_string();
202    }
203
204    let mut lines: Vec<String> = Vec::new();
205
206    if result.downloaded > 0 {
207        lines.push(format!("Downloaded {} blob(s)", result.downloaded));
208    }
209
210    if result.failed > 0 {
211        lines.push(format!("Failed to download {} blob(s)", result.failed));
212
213        let failed_results: Vec<&BlobFetchResult> =
214            result.results.iter().filter(|r| !r.success).collect();
215
216        for r in failed_results.iter().take(5) {
217            let short_hash = if r.hash.len() >= 12 {
218                &r.hash[..12]
219            } else {
220                &r.hash
221            };
222            let err = r.error.as_deref().unwrap_or("unknown error");
223            lines.push(format!("  - {}...: {}", short_hash, err));
224        }
225
226        if failed_results.len() > 5 {
227            lines.push(format!("  ... and {} more", failed_results.len() - 5));
228        }
229    }
230
231    lines.join("\n")
232}
233
234// ── Internal helpers ──────────────────────────────────────────────────
235
236/// Download a list of blob hashes sequentially, writing each to
237/// `blobs_path/<hash>`.
238async fn download_hashes(
239    hashes: &[String],
240    blobs_path: &Path,
241    client: &ApiClient,
242    on_progress: Option<&OnProgress>,
243) -> FetchMissingBlobsResult {
244    let total = hashes.len();
245    let mut downloaded: usize = 0;
246    let mut failed: usize = 0;
247    let mut results: Vec<BlobFetchResult> = Vec::with_capacity(total);
248
249    for (i, hash) in hashes.iter().enumerate() {
250        if let Some(ref cb) = on_progress {
251            cb(hash, i + 1, total);
252        }
253
254        match client.fetch_blob(hash).await {
255            Ok(Some(data)) => {
256                // Verify content hash matches expected hash before writing
257                let actual_hash = crate::hash::git_sha256::compute_git_sha256_from_bytes(&data);
258                if actual_hash != *hash {
259                    results.push(BlobFetchResult {
260                        hash: hash.clone(),
261                        success: false,
262                        error: Some(format!(
263                            "Content hash mismatch: expected {}, got {}",
264                            hash, actual_hash
265                        )),
266                    });
267                    failed += 1;
268                    continue;
269                }
270
271                let blob_path: PathBuf = blobs_path.join(hash);
272                match tokio::fs::write(&blob_path, &data).await {
273                    Ok(()) => {
274                        results.push(BlobFetchResult {
275                            hash: hash.clone(),
276                            success: true,
277                            error: None,
278                        });
279                        downloaded += 1;
280                    }
281                    Err(e) => {
282                        results.push(BlobFetchResult {
283                            hash: hash.clone(),
284                            success: false,
285                            error: Some(format!("Failed to write blob to disk: {}", e)),
286                        });
287                        failed += 1;
288                    }
289                }
290            }
291            Ok(None) => {
292                results.push(BlobFetchResult {
293                    hash: hash.clone(),
294                    success: false,
295                    error: Some("Blob not found on server".to_string()),
296                });
297                failed += 1;
298            }
299            Err(e) => {
300                results.push(BlobFetchResult {
301                    hash: hash.clone(),
302                    success: false,
303                    error: Some(e.to_string()),
304                });
305                failed += 1;
306            }
307        }
308    }
309
310    FetchMissingBlobsResult {
311        total,
312        downloaded,
313        failed,
314        skipped: 0,
315        results,
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::manifest::schema::{PatchFileInfo, PatchManifest, PatchRecord};
    use std::collections::HashMap;

    /// Build a one-package manifest (`pkg:npm/test@1.0.0`) containing one
    /// file per entry of `after_hashes`, each with a distinct synthetic
    /// `beforeHash`.
    fn make_manifest_with_hashes(after_hashes: &[&str]) -> PatchManifest {
        let files: HashMap<_, _> = after_hashes
            .iter()
            .enumerate()
            .map(|(idx, after)| {
                let info = PatchFileInfo {
                    before_hash: format!("before{}{:06}", "0".repeat(58), idx),
                    after_hash: after.to_string(),
                };
                (format!("package/file{}.js", idx), info)
            })
            .collect();

        let record = PatchRecord {
            uuid: "test-uuid".to_string(),
            exported_at: "2024-01-01T00:00:00Z".to_string(),
            files,
            vulnerabilities: HashMap::new(),
            description: "test".to_string(),
            license: "MIT".to_string(),
            tier: "free".to_string(),
        };

        PatchManifest {
            patches: HashMap::from([("pkg:npm/test@1.0.0".to_string(), record)]),
        }
    }

    /// Create a temp dir containing an empty `blobs` subdirectory. The
    /// `TempDir` guard is returned so the directory lives as long as the
    /// caller keeps it bound.
    async fn blob_store() -> (tempfile::TempDir, std::path::PathBuf) {
        let tmp = tempfile::tempdir().unwrap();
        let store = tmp.path().join("blobs");
        tokio::fs::create_dir_all(&store).await.unwrap();
        (tmp, store)
    }

    /// Shorthand for a single [`BlobFetchResult`].
    fn outcome(hash: impl Into<String>, success: bool, error: Option<&str>) -> BlobFetchResult {
        BlobFetchResult {
            hash: hash.into(),
            success,
            error: error.map(str::to_string),
        }
    }

    /// Shorthand for an aggregate result with no skips; `total` is the
    /// number of per-blob results.
    fn summary(
        downloaded: usize,
        failed: usize,
        results: Vec<BlobFetchResult>,
    ) -> FetchMissingBlobsResult {
        FetchMissingBlobsResult {
            total: results.len(),
            downloaded,
            failed,
            skipped: 0,
            results,
        }
    }

    #[tokio::test]
    async fn test_get_missing_blobs_all_missing() {
        let (_tmp, store) = blob_store().await;
        let first = "a".repeat(64);
        let second = "b".repeat(64);
        let manifest = make_manifest_with_hashes(&[&first, &second]);

        let missing = get_missing_blobs(&manifest, &store).await;
        assert_eq!(missing.len(), 2);
        for expected in [&first, &second].iter() {
            assert!(missing.contains(*expected));
        }
    }

    #[tokio::test]
    async fn test_get_missing_blobs_some_present() {
        let (_tmp, store) = blob_store().await;
        let present = "a".repeat(64);
        let absent = "b".repeat(64);

        // Materialize `present` so that only `absent` is reported missing.
        tokio::fs::write(store.join(&present), b"data").await.unwrap();

        let manifest = make_manifest_with_hashes(&[&present, &absent]);
        let missing = get_missing_blobs(&manifest, &store).await;
        assert_eq!(missing.len(), 1);
        assert!(missing.contains(&absent));
        assert!(!missing.contains(&present));
    }

    #[tokio::test]
    async fn test_get_missing_blobs_empty_manifest() {
        let (_tmp, store) = blob_store().await;
        let missing = get_missing_blobs(&PatchManifest::new(), &store).await;
        assert!(missing.is_empty());
    }

    #[test]
    fn test_format_fetch_result_all_present() {
        let result = summary(0, 0, Vec::new());
        assert_eq!(format_fetch_result(&result), "All blobs are present locally.");
    }

    #[test]
    fn test_format_fetch_result_some_downloaded() {
        let result = summary(
            2,
            1,
            vec![
                outcome("a".repeat(64), true, None),
                outcome("b".repeat(64), true, None),
                outcome("c".repeat(64), false, Some("Blob not found on server")),
            ],
        );
        let text = format_fetch_result(&result);
        for needle in [
            "Downloaded 2 blob(s)",
            "Failed to download 1 blob(s)",
            "cccccccccccc...",
            "Blob not found on server",
        ]
        .iter()
        {
            assert!(text.contains(*needle), "missing {:?} in {:?}", needle, text);
        }
    }

    #[test]
    fn test_format_fetch_result_truncates_at_5() {
        let failures: Vec<BlobFetchResult> = (0..8)
            .map(|n| outcome(format!("{:0>64}", n), false, Some(format!("error {}", n).as_str())))
            .collect();
        let text = format_fetch_result(&summary(0, 8, failures));
        assert!(text.contains("... and 3 more"));
    }

    // ── Group 8: format edge cases ───────────────────────────────────

    #[test]
    fn test_format_only_downloaded() {
        let results = ["a", "b", "c"]
            .iter()
            .map(|c| outcome(c.repeat(64), true, None))
            .collect();
        let text = format_fetch_result(&summary(3, 0, results));
        assert!(text.contains("Downloaded 3 blob(s)"));
        assert!(!text.contains("Failed"));
    }

    #[test]
    fn test_format_short_hash() {
        let result = summary(0, 1, vec![outcome("abc", false, Some("not found"))]);
        // A hash shorter than 12 chars is shown in full.
        assert!(format_fetch_result(&result).contains("abc..."));
    }

    #[test]
    fn test_format_error_none() {
        let result = summary(0, 1, vec![outcome("d".repeat(64), false, None)]);
        assert!(format_fetch_result(&result).contains("unknown error"));
    }

    #[test]
    fn test_format_only_failed() {
        let result = summary(
            0,
            2,
            vec![
                outcome("a".repeat(64), false, Some("timeout")),
                outcome("b".repeat(64), false, Some("timeout")),
            ],
        );
        let text = format_fetch_result(&result);
        assert!(!text.contains("Downloaded"));
        assert!(text.contains("Failed to download 2 blob(s)"));
    }
}