Skip to main content

socket_patch_core/manifest/
operations.rs

1use std::collections::HashSet;
2use std::path::Path;
3
4use crate::manifest::schema::PatchManifest;
5
6/// Get all blob hashes referenced by a manifest (both beforeHash and afterHash).
7/// Used for garbage collection and validation.
8pub fn get_referenced_blobs(manifest: &PatchManifest) -> HashSet<String> {
9    let mut blobs = HashSet::new();
10
11    for record in manifest.patches.values() {
12        for file_info in record.files.values() {
13            blobs.insert(file_info.before_hash.clone());
14            blobs.insert(file_info.after_hash.clone());
15        }
16    }
17
18    blobs
19}
20
21/// Get only afterHash blobs referenced by a manifest.
22/// Used for apply operations -- we only need the patched file content, not the original.
23/// This saves disk space since beforeHash blobs are not needed for applying patches.
24pub fn get_after_hash_blobs(manifest: &PatchManifest) -> HashSet<String> {
25    let mut blobs = HashSet::new();
26
27    for record in manifest.patches.values() {
28        for file_info in record.files.values() {
29            blobs.insert(file_info.after_hash.clone());
30        }
31    }
32
33    blobs
34}
35
36/// Get only beforeHash blobs referenced by a manifest.
37/// Used for rollback operations -- we need the original file content to restore.
38pub fn get_before_hash_blobs(manifest: &PatchManifest) -> HashSet<String> {
39    let mut blobs = HashSet::new();
40
41    for record in manifest.patches.values() {
42        for file_info in record.files.values() {
43            blobs.insert(file_info.before_hash.clone());
44        }
45    }
46
47    blobs
48}
49
/// Differences between two manifests, expressed as sets of PURLs.
///
/// `Default` yields a diff with all three sets empty (i.e. "no changes"),
/// and `PartialEq`/`Eq` allow two diffs to be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ManifestDiff {
    /// PURLs present in the new manifest but not in the old one.
    pub added: HashSet<String>,
    /// PURLs present in the old manifest but not in the new one.
    pub removed: HashSet<String>,
    /// PURLs present in both manifests but with different UUIDs.
    pub modified: HashSet<String>,
}
60
61/// Calculate differences between two manifests.
62/// Patches are compared by UUID: if the PURL exists in both manifests but the
63/// UUID changed, the patch is considered modified.
64pub fn diff_manifests(old_manifest: &PatchManifest, new_manifest: &PatchManifest) -> ManifestDiff {
65    let old_purls: HashSet<&String> = old_manifest.patches.keys().collect();
66    let new_purls: HashSet<&String> = new_manifest.patches.keys().collect();
67
68    let mut added = HashSet::new();
69    let mut removed = HashSet::new();
70    let mut modified = HashSet::new();
71
72    // Find added and modified
73    for purl in &new_purls {
74        if !old_purls.contains(purl) {
75            added.insert((*purl).clone());
76        } else {
77            let old_patch = &old_manifest.patches[*purl];
78            let new_patch = &new_manifest.patches[*purl];
79            if old_patch.uuid != new_patch.uuid {
80                modified.insert((*purl).clone());
81            }
82        }
83    }
84
85    // Find removed
86    for purl in &old_purls {
87        if !new_purls.contains(purl) {
88            removed.insert((*purl).clone());
89        }
90    }
91
92    ManifestDiff {
93        added,
94        removed,
95        modified,
96    }
97}
98
99/// Validate a parsed JSON value as a PatchManifest.
100/// Returns Ok(manifest) if valid, or Err(message) if invalid.
101pub fn validate_manifest(value: &serde_json::Value) -> Result<PatchManifest, String> {
102    serde_json::from_value::<PatchManifest>(value.clone())
103        .map_err(|e| format!("Invalid manifest: {}", e))
104}
105
106/// Read and parse a manifest from the filesystem.
107/// Returns Ok(None) if the file does not exist or cannot be parsed.
108pub async fn read_manifest(path: impl AsRef<Path>) -> Result<Option<PatchManifest>, std::io::Error> {
109    let path = path.as_ref();
110
111    let content = match tokio::fs::read_to_string(path).await {
112        Ok(c) => c,
113        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
114        Err(_) => return Ok(None),
115    };
116
117    let parsed: serde_json::Value = match serde_json::from_str(&content) {
118        Ok(v) => v,
119        Err(_) => return Ok(None),
120    };
121
122    match validate_manifest(&parsed) {
123        Ok(manifest) => Ok(Some(manifest)),
124        Err(_) => Ok(None),
125    }
126}
127
128/// Write a manifest to the filesystem with pretty-printed JSON.
129pub async fn write_manifest(
130    path: impl AsRef<Path>,
131    manifest: &PatchManifest,
132) -> Result<(), std::io::Error> {
133    let content = serde_json::to_string_pretty(manifest)
134        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
135    tokio::fs::write(path, content).await
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::manifest::schema::{PatchFileInfo, PatchRecord};
    use std::collections::HashMap;

    const TEST_UUID_1: &str = "11111111-1111-4111-8111-111111111111";
    const TEST_UUID_2: &str = "22222222-2222-4222-8222-222222222222";

    const BEFORE_HASH_1: &str =
        "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1111";
    const AFTER_HASH_1: &str =
        "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb1111";
    const BEFORE_HASH_2: &str =
        "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc2222";
    const AFTER_HASH_2: &str =
        "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd2222";
    const BEFORE_HASH_3: &str =
        "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee3333";
    const AFTER_HASH_3: &str =
        "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3333";

    // PURLs used as manifest keys throughout the fixtures.
    const PURL_A: &str = "pkg:npm/pkg-a@1.0.0";
    const PURL_B: &str = "pkg:npm/pkg-b@2.0.0";

    /// Build one `(path, PatchFileInfo)` entry for a record's file map.
    fn file_entry(path: &str, before: &str, after: &str) -> (String, PatchFileInfo) {
        (
            path.to_string(),
            PatchFileInfo {
                before_hash: before.to_string(),
                after_hash: after.to_string(),
            },
        )
    }

    /// Build a patch record with the fixed export date, license and tier
    /// shared by all fixtures.
    fn patch_record(
        uuid: &str,
        description: &str,
        files: HashMap<String, PatchFileInfo>,
    ) -> PatchRecord {
        PatchRecord {
            uuid: uuid.to_string(),
            exported_at: "2024-01-01T00:00:00Z".to_string(),
            files,
            vulnerabilities: HashMap::new(),
            description: description.to_string(),
            license: "MIT".to_string(),
            tier: "free".to_string(),
        }
    }

    /// Two-patch fixture: pkg-a with two files, pkg-b with one, six distinct hashes.
    fn create_test_manifest() -> PatchManifest {
        let files_a: HashMap<String, PatchFileInfo> = vec![
            file_entry("package/index.js", BEFORE_HASH_1, AFTER_HASH_1),
            file_entry("package/lib/utils.js", BEFORE_HASH_2, AFTER_HASH_2),
        ]
        .into_iter()
        .collect();

        let files_b: HashMap<String, PatchFileInfo> =
            vec![file_entry("package/main.js", BEFORE_HASH_3, AFTER_HASH_3)]
                .into_iter()
                .collect();

        let mut patches = HashMap::new();
        patches.insert(
            PURL_A.to_string(),
            patch_record(TEST_UUID_1, "Test patch 1", files_a),
        );
        patches.insert(
            PURL_B.to_string(),
            patch_record(TEST_UUID_2, "Test patch 2", files_b),
        );

        PatchManifest { patches }
    }

    #[test]
    fn test_get_referenced_blobs_returns_all() {
        let blobs = get_referenced_blobs(&create_test_manifest());

        assert_eq!(blobs.len(), 6);
        for hash in [
            BEFORE_HASH_1,
            AFTER_HASH_1,
            BEFORE_HASH_2,
            AFTER_HASH_2,
            BEFORE_HASH_3,
            AFTER_HASH_3,
        ]
        .iter()
        {
            assert!(blobs.contains(*hash));
        }
    }

    #[test]
    fn test_get_referenced_blobs_empty_manifest() {
        let blobs = get_referenced_blobs(&PatchManifest::new());
        assert!(blobs.is_empty());
    }

    #[test]
    fn test_get_referenced_blobs_deduplicates() {
        // file2 reuses file1's before hash, so only three hashes are unique.
        let files: HashMap<String, PatchFileInfo> = vec![
            file_entry("package/file1.js", BEFORE_HASH_1, AFTER_HASH_1),
            file_entry("package/file2.js", BEFORE_HASH_1, AFTER_HASH_2),
        ]
        .into_iter()
        .collect();

        let mut patches = HashMap::new();
        patches.insert(PURL_A.to_string(), patch_record(TEST_UUID_1, "Test", files));

        let blobs = get_referenced_blobs(&PatchManifest { patches });
        // 3 unique hashes, not 4
        assert_eq!(blobs.len(), 3);
    }

    #[test]
    fn test_get_after_hash_blobs() {
        let blobs = get_after_hash_blobs(&create_test_manifest());

        assert_eq!(blobs.len(), 3);
        for hash in [AFTER_HASH_1, AFTER_HASH_2, AFTER_HASH_3].iter() {
            assert!(blobs.contains(*hash));
        }
        for hash in [BEFORE_HASH_1, BEFORE_HASH_2, BEFORE_HASH_3].iter() {
            assert!(!blobs.contains(*hash));
        }
    }

    #[test]
    fn test_get_after_hash_blobs_empty() {
        let blobs = get_after_hash_blobs(&PatchManifest::new());
        assert!(blobs.is_empty());
    }

    #[test]
    fn test_get_before_hash_blobs() {
        let blobs = get_before_hash_blobs(&create_test_manifest());

        assert_eq!(blobs.len(), 3);
        for hash in [BEFORE_HASH_1, BEFORE_HASH_2, BEFORE_HASH_3].iter() {
            assert!(blobs.contains(*hash));
        }
        for hash in [AFTER_HASH_1, AFTER_HASH_2, AFTER_HASH_3].iter() {
            assert!(!blobs.contains(*hash));
        }
    }

    #[test]
    fn test_get_before_hash_blobs_empty() {
        let blobs = get_before_hash_blobs(&PatchManifest::new());
        assert!(blobs.is_empty());
    }

    #[test]
    fn test_after_plus_before_equals_all() {
        let manifest = create_test_manifest();
        let all_blobs = get_referenced_blobs(&manifest);
        let after_blobs = get_after_hash_blobs(&manifest);
        let before_blobs = get_before_hash_blobs(&manifest);

        let union: HashSet<String> = &after_blobs | &before_blobs;
        assert_eq!(union.len(), all_blobs.len());
        assert!(all_blobs.is_subset(&union));
    }

    #[test]
    fn test_diff_manifests_added() {
        let diff = diff_manifests(&PatchManifest::new(), &create_test_manifest());

        assert_eq!(diff.added.len(), 2);
        assert!(diff.added.contains(PURL_A));
        assert!(diff.added.contains(PURL_B));
        assert!(diff.removed.is_empty());
        assert!(diff.modified.is_empty());
    }

    #[test]
    fn test_diff_manifests_removed() {
        let diff = diff_manifests(&create_test_manifest(), &PatchManifest::new());

        assert!(diff.added.is_empty());
        assert_eq!(diff.removed.len(), 2);
        assert!(diff.removed.contains(PURL_A));
        assert!(diff.removed.contains(PURL_B));
        assert!(diff.modified.is_empty());
    }

    #[test]
    fn test_diff_manifests_modified() {
        let old = create_test_manifest();
        let mut new_manifest = create_test_manifest();
        // Change UUID of pkg-a
        new_manifest.patches.get_mut(PURL_A).unwrap().uuid =
            "33333333-3333-4333-8333-333333333333".to_string();

        let diff = diff_manifests(&old, &new_manifest);
        assert!(diff.added.is_empty());
        assert!(diff.removed.is_empty());
        assert_eq!(diff.modified.len(), 1);
        assert!(diff.modified.contains(PURL_A));
    }

    #[test]
    fn test_diff_manifests_same() {
        let diff = diff_manifests(&create_test_manifest(), &create_test_manifest());

        assert!(diff.added.is_empty());
        assert!(diff.removed.is_empty());
        assert!(diff.modified.is_empty());
    }

    #[test]
    fn test_validate_manifest_valid() {
        let json = serde_json::json!({
            "patches": {
                "pkg:npm/test@1.0.0": {
                    "uuid": "11111111-1111-4111-8111-111111111111",
                    "exportedAt": "2024-01-01T00:00:00Z",
                    "files": {},
                    "vulnerabilities": {},
                    "description": "test",
                    "license": "MIT",
                    "tier": "free"
                }
            }
        });

        let manifest = validate_manifest(&json).expect("manifest should validate");
        assert_eq!(manifest.patches.len(), 1);
    }

    #[test]
    fn test_validate_manifest_invalid() {
        let json = serde_json::json!({
            "patches": "not-an-object"
        });

        assert!(validate_manifest(&json).is_err());
    }

    #[test]
    fn test_validate_manifest_missing_fields() {
        let json = serde_json::json!({
            "patches": {
                "pkg:npm/test@1.0.0": {
                    "uuid": "test"
                }
            }
        });

        assert!(validate_manifest(&json).is_err());
    }

    #[tokio::test]
    async fn test_read_manifest_not_found() {
        let outcome = read_manifest("/nonexistent/path/manifest.json").await;
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_write_and_read_manifest() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("manifest.json");

        write_manifest(&path, &create_test_manifest())
            .await
            .unwrap();

        let read_back = read_manifest(&path).await.unwrap();
        assert!(read_back.is_some());
        assert_eq!(read_back.unwrap().patches.len(), 2);
    }
}
453}