Skip to main content

socket_patch_core/manifest/
operations.rs

1use std::collections::HashSet;
2use std::path::Path;
3
4use crate::manifest::schema::PatchManifest;
5
6/// Get all blob hashes referenced by a manifest (both beforeHash and afterHash).
7/// Used for garbage collection and validation.
8pub fn get_referenced_blobs(manifest: &PatchManifest) -> HashSet<String> {
9    let mut blobs = HashSet::new();
10
11    for record in manifest.patches.values() {
12        for file_info in record.files.values() {
13            blobs.insert(file_info.before_hash.clone());
14            blobs.insert(file_info.after_hash.clone());
15        }
16    }
17
18    blobs
19}
20
21/// Get only afterHash blobs referenced by a manifest.
22/// Used for apply operations -- we only need the patched file content, not the original.
23/// This saves disk space since beforeHash blobs are not needed for applying patches.
24pub fn get_after_hash_blobs(manifest: &PatchManifest) -> HashSet<String> {
25    let mut blobs = HashSet::new();
26
27    for record in manifest.patches.values() {
28        for file_info in record.files.values() {
29            blobs.insert(file_info.after_hash.clone());
30        }
31    }
32
33    blobs
34}
35
36/// Get only beforeHash blobs referenced by a manifest.
37/// Used for rollback operations -- we need the original file content to restore.
38pub fn get_before_hash_blobs(manifest: &PatchManifest) -> HashSet<String> {
39    let mut blobs = HashSet::new();
40
41    for record in manifest.patches.values() {
42        for file_info in record.files.values() {
43            blobs.insert(file_info.before_hash.clone());
44        }
45    }
46
47    blobs
48}
49
/// Differences between two manifests, expressed as sets of PURLs.
///
/// Produced by [`diff_manifests`]. The three sets are disjoint: a PURL is
/// either added, removed, or modified, never more than one of these.
///
/// `Default` yields an empty diff (no changes), and `PartialEq`/`Eq` allow two
/// diffs to be compared directly, e.g. in assertions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ManifestDiff {
    /// PURLs present in new but not old.
    pub added: HashSet<String>,
    /// PURLs present in old but not new.
    pub removed: HashSet<String>,
    /// PURLs present in both but with different UUIDs.
    pub modified: HashSet<String>,
}
60
61/// Calculate differences between two manifests.
62/// Patches are compared by UUID: if the PURL exists in both manifests but the
63/// UUID changed, the patch is considered modified.
64pub fn diff_manifests(old_manifest: &PatchManifest, new_manifest: &PatchManifest) -> ManifestDiff {
65    let old_purls: HashSet<&String> = old_manifest.patches.keys().collect();
66    let new_purls: HashSet<&String> = new_manifest.patches.keys().collect();
67
68    let mut added = HashSet::new();
69    let mut removed = HashSet::new();
70    let mut modified = HashSet::new();
71
72    // Find added and modified
73    for purl in &new_purls {
74        if !old_purls.contains(purl) {
75            added.insert((*purl).clone());
76        } else {
77            let old_patch = &old_manifest.patches[*purl];
78            let new_patch = &new_manifest.patches[*purl];
79            if old_patch.uuid != new_patch.uuid {
80                modified.insert((*purl).clone());
81            }
82        }
83    }
84
85    // Find removed
86    for purl in &old_purls {
87        if !new_purls.contains(purl) {
88            removed.insert((*purl).clone());
89        }
90    }
91
92    ManifestDiff {
93        added,
94        removed,
95        modified,
96    }
97}
98
99/// Validate a parsed JSON value as a PatchManifest.
100/// Returns Ok(manifest) if valid, or Err(message) if invalid.
101pub fn validate_manifest(value: &serde_json::Value) -> Result<PatchManifest, String> {
102    serde_json::from_value::<PatchManifest>(value.clone())
103        .map_err(|e| format!("Invalid manifest: {}", e))
104}
105
106/// Read and parse a manifest from the filesystem.
107/// Returns Ok(None) if the file does not exist.
108/// Returns Err for I/O errors, JSON parse errors, or validation errors.
109pub async fn read_manifest(path: impl AsRef<Path>) -> Result<Option<PatchManifest>, std::io::Error> {
110    let path = path.as_ref();
111
112    let content = match tokio::fs::read_to_string(path).await {
113        Ok(c) => c,
114        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
115        Err(e) => return Err(e),   // FIX: propagate actual I/O error
116    };
117
118    let parsed: serde_json::Value = match serde_json::from_str(&content) {
119        Ok(v) => v,
120        Err(e) => return Err(std::io::Error::new(
121            std::io::ErrorKind::InvalidData,
122            format!("Failed to parse manifest JSON: {}", e),
123        )),
124    };
125
126    match validate_manifest(&parsed) {
127        Ok(manifest) => Ok(Some(manifest)),
128        Err(e) => Err(std::io::Error::new(
129            std::io::ErrorKind::InvalidData,
130            e,
131        )),
132    }
133}
134
135/// Write a manifest to the filesystem with pretty-printed JSON.
136pub async fn write_manifest(
137    path: impl AsRef<Path>,
138    manifest: &PatchManifest,
139) -> Result<(), std::io::Error> {
140    let content = serde_json::to_string_pretty(manifest)
141        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
142    tokio::fs::write(path, content).await
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::manifest::schema::{PatchFileInfo, PatchRecord};
    use std::collections::HashMap;

    const TEST_UUID_1: &str = "11111111-1111-4111-8111-111111111111";
    const TEST_UUID_2: &str = "22222222-2222-4222-8222-222222222222";

    const BEFORE_HASH_1: &str =
        "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1111";
    const AFTER_HASH_1: &str =
        "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb1111";
    const BEFORE_HASH_2: &str =
        "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc2222";
    const AFTER_HASH_2: &str =
        "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd2222";
    const BEFORE_HASH_3: &str =
        "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee3333";
    const AFTER_HASH_3: &str =
        "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3333";

    const PURL_A: &str = "pkg:npm/pkg-a@1.0.0";
    const PURL_B: &str = "pkg:npm/pkg-b@2.0.0";

    /// Build a files map from `(path, before_hash, after_hash)` triples.
    fn files_from(entries: &[(&str, &str, &str)]) -> HashMap<String, PatchFileInfo> {
        entries
            .iter()
            .map(|&(path, before, after)| {
                (
                    path.to_string(),
                    PatchFileInfo {
                        before_hash: before.to_string(),
                        after_hash: after.to_string(),
                    },
                )
            })
            .collect()
    }

    /// Build a patch record with fixed boilerplate fields and the given
    /// UUID, description and files.
    fn record(uuid: &str, description: &str, files: HashMap<String, PatchFileInfo>) -> PatchRecord {
        PatchRecord {
            uuid: uuid.to_string(),
            exported_at: "2024-01-01T00:00:00Z".to_string(),
            files,
            vulnerabilities: HashMap::new(),
            description: description.to_string(),
            license: "MIT".to_string(),
            tier: "free".to_string(),
        }
    }

    /// Two-package fixture: pkg-a patches two files, pkg-b patches one, and
    /// all six hashes are distinct.
    fn create_test_manifest() -> PatchManifest {
        let files_a = files_from(&[
            ("package/index.js", BEFORE_HASH_1, AFTER_HASH_1),
            ("package/lib/utils.js", BEFORE_HASH_2, AFTER_HASH_2),
        ]);
        let files_b = files_from(&[("package/main.js", BEFORE_HASH_3, AFTER_HASH_3)]);

        let mut patches = HashMap::new();
        patches.insert(
            PURL_A.to_string(),
            record(TEST_UUID_1, "Test patch 1", files_a),
        );
        patches.insert(
            PURL_B.to_string(),
            record(TEST_UUID_2, "Test patch 2", files_b),
        );

        PatchManifest { patches }
    }

    #[test]
    fn test_get_referenced_blobs_returns_all() {
        let blobs = get_referenced_blobs(&create_test_manifest());

        assert_eq!(blobs.len(), 6);
        for hash in [
            BEFORE_HASH_1,
            AFTER_HASH_1,
            BEFORE_HASH_2,
            AFTER_HASH_2,
            BEFORE_HASH_3,
            AFTER_HASH_3,
        ] {
            assert!(blobs.contains(hash));
        }
    }

    #[test]
    fn test_get_referenced_blobs_empty_manifest() {
        let blobs = get_referenced_blobs(&PatchManifest::new());
        assert!(blobs.is_empty());
    }

    #[test]
    fn test_get_referenced_blobs_deduplicates() {
        // file2 shares its before hash with file1.
        let files = files_from(&[
            ("package/file1.js", BEFORE_HASH_1, AFTER_HASH_1),
            ("package/file2.js", BEFORE_HASH_1, AFTER_HASH_2),
        ]);

        let mut patches = HashMap::new();
        patches.insert(PURL_A.to_string(), record(TEST_UUID_1, "Test", files));

        let blobs = get_referenced_blobs(&PatchManifest { patches });
        // 3 unique hashes, not 4
        assert_eq!(blobs.len(), 3);
    }

    #[test]
    fn test_get_after_hash_blobs() {
        let blobs = get_after_hash_blobs(&create_test_manifest());

        assert_eq!(blobs.len(), 3);
        for hash in [AFTER_HASH_1, AFTER_HASH_2, AFTER_HASH_3] {
            assert!(blobs.contains(hash));
        }
        for hash in [BEFORE_HASH_1, BEFORE_HASH_2, BEFORE_HASH_3] {
            assert!(!blobs.contains(hash));
        }
    }

    #[test]
    fn test_get_after_hash_blobs_empty() {
        let blobs = get_after_hash_blobs(&PatchManifest::new());
        assert!(blobs.is_empty());
    }

    #[test]
    fn test_get_before_hash_blobs() {
        let blobs = get_before_hash_blobs(&create_test_manifest());

        assert_eq!(blobs.len(), 3);
        for hash in [BEFORE_HASH_1, BEFORE_HASH_2, BEFORE_HASH_3] {
            assert!(blobs.contains(hash));
        }
        for hash in [AFTER_HASH_1, AFTER_HASH_2, AFTER_HASH_3] {
            assert!(!blobs.contains(hash));
        }
    }

    #[test]
    fn test_get_before_hash_blobs_empty() {
        let blobs = get_before_hash_blobs(&PatchManifest::new());
        assert!(blobs.is_empty());
    }

    #[test]
    fn test_after_plus_before_equals_all() {
        let manifest = create_test_manifest();
        let all_blobs = get_referenced_blobs(&manifest);
        let after_blobs = get_after_hash_blobs(&manifest);
        let before_blobs = get_before_hash_blobs(&manifest);

        let combined: HashSet<String> = after_blobs.union(&before_blobs).cloned().collect();
        assert_eq!(combined.len(), all_blobs.len());
        assert!(all_blobs.iter().all(|blob| combined.contains(blob)));
    }

    #[test]
    fn test_diff_manifests_added() {
        let diff = diff_manifests(&PatchManifest::new(), &create_test_manifest());

        assert_eq!(diff.added.len(), 2);
        assert!(diff.added.contains(PURL_A));
        assert!(diff.added.contains(PURL_B));
        assert!(diff.removed.is_empty());
        assert!(diff.modified.is_empty());
    }

    #[test]
    fn test_diff_manifests_removed() {
        let diff = diff_manifests(&create_test_manifest(), &PatchManifest::new());

        assert!(diff.added.is_empty());
        assert_eq!(diff.removed.len(), 2);
        assert!(diff.removed.contains(PURL_A));
        assert!(diff.removed.contains(PURL_B));
        assert!(diff.modified.is_empty());
    }

    #[test]
    fn test_diff_manifests_modified() {
        let old = create_test_manifest();
        let mut new_manifest = create_test_manifest();
        // Change UUID of pkg-a
        let patch_a = new_manifest
            .patches
            .get_mut(PURL_A)
            .expect("fixture always contains pkg-a");
        patch_a.uuid = "33333333-3333-4333-8333-333333333333".to_string();

        let diff = diff_manifests(&old, &new_manifest);
        assert!(diff.added.is_empty());
        assert!(diff.removed.is_empty());
        assert_eq!(diff.modified.len(), 1);
        assert!(diff.modified.contains(PURL_A));
    }

    #[test]
    fn test_diff_manifests_same() {
        let diff = diff_manifests(&create_test_manifest(), &create_test_manifest());

        assert!(diff.added.is_empty());
        assert!(diff.removed.is_empty());
        assert!(diff.modified.is_empty());
    }

    #[test]
    fn test_validate_manifest_valid() {
        let json = serde_json::json!({
            "patches": {
                "pkg:npm/test@1.0.0": {
                    "uuid": "11111111-1111-4111-8111-111111111111",
                    "exportedAt": "2024-01-01T00:00:00Z",
                    "files": {},
                    "vulnerabilities": {},
                    "description": "test",
                    "license": "MIT",
                    "tier": "free"
                }
            }
        });

        let manifest = validate_manifest(&json).expect("well-formed manifest must validate");
        assert_eq!(manifest.patches.len(), 1);
    }

    #[test]
    fn test_validate_manifest_invalid() {
        let json = serde_json::json!({ "patches": "not-an-object" });
        assert!(validate_manifest(&json).is_err());
    }

    #[test]
    fn test_validate_manifest_missing_fields() {
        let json = serde_json::json!({
            "patches": {
                "pkg:npm/test@1.0.0": {
                    "uuid": "test"
                }
            }
        });
        assert!(validate_manifest(&json).is_err());
    }

    #[tokio::test]
    async fn test_read_manifest_not_found() {
        let result = read_manifest("/nonexistent/path/manifest.json").await;
        assert!(matches!(result, Ok(None)));
    }

    #[tokio::test]
    async fn test_write_and_read_manifest() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("manifest.json");

        write_manifest(&path, &create_test_manifest())
            .await
            .unwrap();

        let read_back = read_manifest(&path)
            .await
            .unwrap()
            .expect("manifest written above must exist");
        assert_eq!(read_back.patches.len(), 2);
    }
}