socket_patch_core/manifest/
operations.rs1use std::collections::HashSet;
2use std::path::Path;
3
4use crate::manifest::schema::PatchManifest;
5
6pub fn get_referenced_blobs(manifest: &PatchManifest) -> HashSet<String> {
9 let mut blobs = HashSet::new();
10
11 for record in manifest.patches.values() {
12 for file_info in record.files.values() {
13 blobs.insert(file_info.before_hash.clone());
14 blobs.insert(file_info.after_hash.clone());
15 }
16 }
17
18 blobs
19}
20
21pub fn get_after_hash_blobs(manifest: &PatchManifest) -> HashSet<String> {
25 let mut blobs = HashSet::new();
26
27 for record in manifest.patches.values() {
28 for file_info in record.files.values() {
29 blobs.insert(file_info.after_hash.clone());
30 }
31 }
32
33 blobs
34}
35
36pub fn get_before_hash_blobs(manifest: &PatchManifest) -> HashSet<String> {
39 let mut blobs = HashSet::new();
40
41 for record in manifest.patches.values() {
42 for file_info in record.files.values() {
43 blobs.insert(file_info.before_hash.clone());
44 }
45 }
46
47 blobs
48}
49
/// Package URLs (PURLs) that differ between two manifests, grouped by the
/// kind of change. Produced by [`diff_manifests`].
///
/// `PartialEq`/`Eq` let callers compare whole diffs (e.g. with `assert_eq!`),
/// and `Default` yields an empty diff with all three sets empty.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ManifestDiff {
    /// PURLs present in the new manifest but not in the old one.
    pub added: HashSet<String>,
    /// PURLs present in the old manifest but not in the new one.
    pub removed: HashSet<String>,
    /// PURLs present in both manifests whose patch UUID changed.
    pub modified: HashSet<String>,
}
60
61pub fn diff_manifests(old_manifest: &PatchManifest, new_manifest: &PatchManifest) -> ManifestDiff {
65 let old_purls: HashSet<&String> = old_manifest.patches.keys().collect();
66 let new_purls: HashSet<&String> = new_manifest.patches.keys().collect();
67
68 let mut added = HashSet::new();
69 let mut removed = HashSet::new();
70 let mut modified = HashSet::new();
71
72 for purl in &new_purls {
74 if !old_purls.contains(purl) {
75 added.insert((*purl).clone());
76 } else {
77 let old_patch = &old_manifest.patches[*purl];
78 let new_patch = &new_manifest.patches[*purl];
79 if old_patch.uuid != new_patch.uuid {
80 modified.insert((*purl).clone());
81 }
82 }
83 }
84
85 for purl in &old_purls {
87 if !new_purls.contains(purl) {
88 removed.insert((*purl).clone());
89 }
90 }
91
92 ManifestDiff {
93 added,
94 removed,
95 modified,
96 }
97}
98
99pub fn validate_manifest(value: &serde_json::Value) -> Result<PatchManifest, String> {
102 serde_json::from_value::<PatchManifest>(value.clone())
103 .map_err(|e| format!("Invalid manifest: {}", e))
104}
105
106pub async fn read_manifest(path: impl AsRef<Path>) -> Result<Option<PatchManifest>, std::io::Error> {
110 let path = path.as_ref();
111
112 let content = match tokio::fs::read_to_string(path).await {
113 Ok(c) => c,
114 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
115 Err(e) => return Err(e), };
117
118 let parsed: serde_json::Value = match serde_json::from_str(&content) {
119 Ok(v) => v,
120 Err(e) => return Err(std::io::Error::new(
121 std::io::ErrorKind::InvalidData,
122 format!("Failed to parse manifest JSON: {}", e),
123 )),
124 };
125
126 match validate_manifest(&parsed) {
127 Ok(manifest) => Ok(Some(manifest)),
128 Err(e) => Err(std::io::Error::new(
129 std::io::ErrorKind::InvalidData,
130 e,
131 )),
132 }
133}
134
135pub async fn write_manifest(
137 path: impl AsRef<Path>,
138 manifest: &PatchManifest,
139) -> Result<(), std::io::Error> {
140 let content = serde_json::to_string_pretty(manifest)
141 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
142 tokio::fs::write(path, content).await
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::manifest::schema::{PatchFileInfo, PatchRecord};
    use std::collections::HashMap;

    const TEST_UUID_1: &str = "11111111-1111-4111-8111-111111111111";
    const TEST_UUID_2: &str = "22222222-2222-4222-8222-222222222222";
    const TEST_UUID_3: &str = "33333333-3333-4333-8333-333333333333";

    const BEFORE_HASH_1: &str =
        "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1111";
    const AFTER_HASH_1: &str =
        "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb1111";
    const BEFORE_HASH_2: &str =
        "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc2222";
    const AFTER_HASH_2: &str =
        "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd2222";
    const BEFORE_HASH_3: &str =
        "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee3333";
    const AFTER_HASH_3: &str =
        "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3333";

    /// Builds a manifest with two packages:
    /// `pkg-a@1.0.0` (two files, hashes 1 and 2) and `pkg-b@2.0.0` (one file, hash 3).
    fn create_test_manifest() -> PatchManifest {
        let mut patches = HashMap::new();

        let mut files_a = HashMap::new();
        files_a.insert(
            "package/index.js".to_string(),
            PatchFileInfo {
                before_hash: BEFORE_HASH_1.to_string(),
                after_hash: AFTER_HASH_1.to_string(),
            },
        );
        files_a.insert(
            "package/lib/utils.js".to_string(),
            PatchFileInfo {
                before_hash: BEFORE_HASH_2.to_string(),
                after_hash: AFTER_HASH_2.to_string(),
            },
        );

        patches.insert(
            "pkg:npm/pkg-a@1.0.0".to_string(),
            PatchRecord {
                uuid: TEST_UUID_1.to_string(),
                exported_at: "2024-01-01T00:00:00Z".to_string(),
                files: files_a,
                vulnerabilities: HashMap::new(),
                description: "Test patch 1".to_string(),
                license: "MIT".to_string(),
                tier: "free".to_string(),
            },
        );

        let mut files_b = HashMap::new();
        files_b.insert(
            "package/main.js".to_string(),
            PatchFileInfo {
                before_hash: BEFORE_HASH_3.to_string(),
                after_hash: AFTER_HASH_3.to_string(),
            },
        );

        patches.insert(
            "pkg:npm/pkg-b@2.0.0".to_string(),
            PatchRecord {
                uuid: TEST_UUID_2.to_string(),
                exported_at: "2024-01-01T00:00:00Z".to_string(),
                files: files_b,
                vulnerabilities: HashMap::new(),
                description: "Test patch 2".to_string(),
                license: "MIT".to_string(),
                tier: "free".to_string(),
            },
        );

        PatchManifest { patches }
    }

    #[test]
    fn test_get_referenced_blobs_returns_all() {
        let manifest = create_test_manifest();
        let blobs = get_referenced_blobs(&manifest);

        assert_eq!(blobs.len(), 6);
        assert!(blobs.contains(BEFORE_HASH_1));
        assert!(blobs.contains(AFTER_HASH_1));
        assert!(blobs.contains(BEFORE_HASH_2));
        assert!(blobs.contains(AFTER_HASH_2));
        assert!(blobs.contains(BEFORE_HASH_3));
        assert!(blobs.contains(AFTER_HASH_3));
    }

    #[test]
    fn test_get_referenced_blobs_empty_manifest() {
        let manifest = PatchManifest::new();
        let blobs = get_referenced_blobs(&manifest);
        assert_eq!(blobs.len(), 0);
    }

    #[test]
    fn test_get_referenced_blobs_deduplicates() {
        let mut files = HashMap::new();
        files.insert(
            "package/file1.js".to_string(),
            PatchFileInfo {
                before_hash: BEFORE_HASH_1.to_string(),
                after_hash: AFTER_HASH_1.to_string(),
            },
        );
        files.insert(
            "package/file2.js".to_string(),
            PatchFileInfo {
                // Same before-hash as file1.js: must only be counted once.
                before_hash: BEFORE_HASH_1.to_string(),
                after_hash: AFTER_HASH_2.to_string(),
            },
        );

        let mut patches = HashMap::new();
        patches.insert(
            "pkg:npm/pkg-a@1.0.0".to_string(),
            PatchRecord {
                uuid: TEST_UUID_1.to_string(),
                exported_at: "2024-01-01T00:00:00Z".to_string(),
                files,
                vulnerabilities: HashMap::new(),
                description: "Test".to_string(),
                license: "MIT".to_string(),
                tier: "free".to_string(),
            },
        );

        let manifest = PatchManifest { patches };
        let blobs = get_referenced_blobs(&manifest);
        assert_eq!(blobs.len(), 3);
        assert!(blobs.contains(BEFORE_HASH_1));
        assert!(blobs.contains(AFTER_HASH_1));
        assert!(blobs.contains(AFTER_HASH_2));
    }

    #[test]
    fn test_get_after_hash_blobs() {
        let manifest = create_test_manifest();
        let blobs = get_after_hash_blobs(&manifest);

        assert_eq!(blobs.len(), 3);
        assert!(blobs.contains(AFTER_HASH_1));
        assert!(blobs.contains(AFTER_HASH_2));
        assert!(blobs.contains(AFTER_HASH_3));
        assert!(!blobs.contains(BEFORE_HASH_1));
        assert!(!blobs.contains(BEFORE_HASH_2));
        assert!(!blobs.contains(BEFORE_HASH_3));
    }

    #[test]
    fn test_get_after_hash_blobs_empty() {
        let manifest = PatchManifest::new();
        let blobs = get_after_hash_blobs(&manifest);
        assert_eq!(blobs.len(), 0);
    }

    #[test]
    fn test_get_before_hash_blobs() {
        let manifest = create_test_manifest();
        let blobs = get_before_hash_blobs(&manifest);

        assert_eq!(blobs.len(), 3);
        assert!(blobs.contains(BEFORE_HASH_1));
        assert!(blobs.contains(BEFORE_HASH_2));
        assert!(blobs.contains(BEFORE_HASH_3));
        assert!(!blobs.contains(AFTER_HASH_1));
        assert!(!blobs.contains(AFTER_HASH_2));
        assert!(!blobs.contains(AFTER_HASH_3));
    }

    #[test]
    fn test_get_before_hash_blobs_empty() {
        let manifest = PatchManifest::new();
        let blobs = get_before_hash_blobs(&manifest);
        assert_eq!(blobs.len(), 0);
    }

    #[test]
    fn test_after_plus_before_equals_all() {
        let manifest = create_test_manifest();
        let all_blobs = get_referenced_blobs(&manifest);
        let after_blobs = get_after_hash_blobs(&manifest);
        let before_blobs = get_before_hash_blobs(&manifest);

        let union: HashSet<String> = after_blobs.union(&before_blobs).cloned().collect();
        // Compare the sets themselves, not just their sizes.
        assert_eq!(union, all_blobs);
    }

    #[test]
    fn test_diff_manifests_added() {
        let old = PatchManifest::new();
        let new_manifest = create_test_manifest();

        let diff = diff_manifests(&old, &new_manifest);
        assert_eq!(diff.added.len(), 2);
        assert!(diff.added.contains("pkg:npm/pkg-a@1.0.0"));
        assert!(diff.added.contains("pkg:npm/pkg-b@2.0.0"));
        assert_eq!(diff.removed.len(), 0);
        assert_eq!(diff.modified.len(), 0);
    }

    #[test]
    fn test_diff_manifests_removed() {
        let old = create_test_manifest();
        let new_manifest = PatchManifest::new();

        let diff = diff_manifests(&old, &new_manifest);
        assert_eq!(diff.added.len(), 0);
        assert_eq!(diff.removed.len(), 2);
        assert!(diff.removed.contains("pkg:npm/pkg-a@1.0.0"));
        assert!(diff.removed.contains("pkg:npm/pkg-b@2.0.0"));
        assert_eq!(diff.modified.len(), 0);
    }

    #[test]
    fn test_diff_manifests_modified() {
        let old = create_test_manifest();
        let mut new_manifest = create_test_manifest();
        new_manifest
            .patches
            .get_mut("pkg:npm/pkg-a@1.0.0")
            .unwrap()
            .uuid = TEST_UUID_3.to_string();

        let diff = diff_manifests(&old, &new_manifest);
        assert_eq!(diff.added.len(), 0);
        assert_eq!(diff.removed.len(), 0);
        assert_eq!(diff.modified.len(), 1);
        assert!(diff.modified.contains("pkg:npm/pkg-a@1.0.0"));
    }

    #[test]
    fn test_diff_manifests_mixed_changes() {
        let old = create_test_manifest();
        let mut new_manifest = create_test_manifest();

        // Move pkg-b's record under a new PURL: pkg-b is removed, pkg-c is added.
        let record_b = new_manifest
            .patches
            .remove("pkg:npm/pkg-b@2.0.0")
            .expect("fixture contains pkg-b");
        new_manifest
            .patches
            .insert("pkg:npm/pkg-c@3.0.0".to_string(), record_b);
        // Re-issue pkg-a's patch: it is modified.
        new_manifest
            .patches
            .get_mut("pkg:npm/pkg-a@1.0.0")
            .expect("fixture contains pkg-a")
            .uuid = TEST_UUID_3.to_string();

        let diff = diff_manifests(&old, &new_manifest);
        assert_eq!(diff.added.len(), 1);
        assert!(diff.added.contains("pkg:npm/pkg-c@3.0.0"));
        assert_eq!(diff.removed.len(), 1);
        assert!(diff.removed.contains("pkg:npm/pkg-b@2.0.0"));
        assert_eq!(diff.modified.len(), 1);
        assert!(diff.modified.contains("pkg:npm/pkg-a@1.0.0"));
    }

    #[test]
    fn test_diff_manifests_same() {
        let old = create_test_manifest();
        let new_manifest = create_test_manifest();

        let diff = diff_manifests(&old, &new_manifest);
        assert_eq!(diff.added.len(), 0);
        assert_eq!(diff.removed.len(), 0);
        assert_eq!(diff.modified.len(), 0);
    }

    #[test]
    fn test_validate_manifest_valid() {
        let json = serde_json::json!({
            "patches": {
                "pkg:npm/test@1.0.0": {
                    "uuid": "11111111-1111-4111-8111-111111111111",
                    "exportedAt": "2024-01-01T00:00:00Z",
                    "files": {},
                    "vulnerabilities": {},
                    "description": "test",
                    "license": "MIT",
                    "tier": "free"
                }
            }
        });

        let result = validate_manifest(&json);
        assert!(result.is_ok());
        let manifest = result.unwrap();
        assert_eq!(manifest.patches.len(), 1);
    }

    #[test]
    fn test_validate_manifest_invalid() {
        let json = serde_json::json!({
            "patches": "not-an-object"
        });

        let result = validate_manifest(&json);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_manifest_missing_fields() {
        let json = serde_json::json!({
            "patches": {
                "pkg:npm/test@1.0.0": {
                    "uuid": "test"
                }
            }
        });

        let result = validate_manifest(&json);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_read_manifest_not_found() {
        // Use a fresh temp dir so the path is guaranteed not to exist on any platform.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("missing").join("manifest.json");

        let result = read_manifest(&path).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_read_manifest_invalid_json() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("manifest.json");
        tokio::fs::write(&path, "{ not valid json").await.unwrap();

        let result = read_manifest(&path).await;
        assert!(matches!(
            &result,
            Err(e) if e.kind() == std::io::ErrorKind::InvalidData
        ));
    }

    #[tokio::test]
    async fn test_read_manifest_schema_mismatch() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("manifest.json");
        tokio::fs::write(&path, r#"{"patches": "not-an-object"}"#)
            .await
            .unwrap();

        let result = read_manifest(&path).await;
        assert!(matches!(
            &result,
            Err(e) if e.kind() == std::io::ErrorKind::InvalidData
        ));
    }

    #[tokio::test]
    async fn test_write_and_read_manifest() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("manifest.json");

        let manifest = create_test_manifest();
        write_manifest(&path, &manifest).await.unwrap();

        let read_back = read_manifest(&path)
            .await
            .unwrap()
            .expect("manifest was just written, so it must exist");
        assert_eq!(read_back.patches.len(), 2);

        // Every PURL must round-trip with its uuid and file hashes intact.
        for (purl, original) in &manifest.patches {
            let restored = read_back
                .patches
                .get(purl)
                .expect("every written PURL should be read back");
            assert_eq!(restored.uuid, original.uuid);
            assert_eq!(restored.files.len(), original.files.len());
        }
        assert_eq!(
            get_referenced_blobs(&read_back),
            get_referenced_blobs(&manifest)
        );
    }
}