// socket_patch_core/manifest/recovery.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4
5use crate::manifest::schema::{PatchFileInfo, PatchManifest, PatchRecord, VulnerabilityInfo};
6
7/// Result of manifest recovery operation.
8#[derive(Debug, Clone)]
9pub struct RecoveryResult {
10    pub manifest: PatchManifest,
11    pub repair_needed: bool,
12    pub invalid_patches: Vec<String>,
13    pub recovered_patches: Vec<String>,
14    pub discarded_patches: Vec<String>,
15}
16
/// Patch data returned from an external source (e.g., a database) by a
/// refetch callback, used to rebuild invalid manifest entries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatchData {
    /// UUID of the patch.
    pub uuid: String,
    /// Package URL the patch applies to. Recovery only accepts the data when
    /// this equals the manifest key being recovered.
    pub purl: String,
    /// Publication timestamp; recovery stores it as the record's `exported_at`.
    pub published_at: String,
    /// Files touched by the patch, keyed by file path.
    pub files: HashMap<String, PatchDataFileInfo>,
    /// Vulnerabilities addressed by the patch, keyed by vulnerability id.
    pub vulnerabilities: HashMap<String, PatchDataVulnerability>,
    pub description: String,
    pub license: String,
    pub tier: String,
}

/// File info from external patch data.
///
/// Both hashes are optional here, whereas a manifest file entry requires both;
/// recovery skips any file that is missing either hash.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatchDataFileInfo {
    pub before_hash: Option<String>,
    pub after_hash: Option<String>,
}

/// Vulnerability info from external patch data; copied field-for-field into
/// the manifest's vulnerability entry during recovery.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatchDataVulnerability {
    pub cves: Vec<String>,
    pub summary: String,
    pub severity: String,
    pub description: String,
}
45
/// Events emitted during recovery through the `on_recovery_event` callback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecoveryEvent {
    /// The input was not an object or had no object-valued `"patches"` key;
    /// an empty manifest is returned.
    CorruptedManifest,
    /// A patch entry failed to parse. `uuid` is the entry's `"uuid"` string,
    /// if it had one.
    InvalidPatch {
        purl: String,
        uuid: Option<String>,
    },
    /// An invalid entry was rebuilt from data returned by the refetch callback.
    RecoveredPatch {
        purl: String,
        uuid: String,
    },
    /// An invalid entry was dropped because the refetch callback found no
    /// patch for `uuid`, or because no refetch callback was configured.
    DiscardedPatchNotFound {
        purl: String,
        uuid: String,
    },
    /// An invalid entry was dropped because the refetched patch belongs to a
    /// different package (`db_purl`) than the manifest key (`purl`).
    DiscardedPatchPurlMismatch {
        purl: String,
        uuid: String,
        db_purl: String,
    },
    /// An invalid entry was dropped because it had no string `"uuid"` to
    /// refetch by.
    DiscardedPatchNoUuid {
        purl: String,
    },
    /// An invalid entry was dropped because the refetch callback returned an
    /// error; `error` is the callback's message.
    RecoveryError {
        purl: String,
        uuid: String,
        error: String,
    },
}
76
77/// Type alias for the refetch callback.
78/// Takes (uuid, optional purl) and returns a future resolving to Option<PatchData>.
79pub type RefetchPatchFn = Box<
80    dyn Fn(String, Option<String>) -> Pin<Box<dyn Future<Output = Result<Option<PatchData>, String>> + Send>>
81        + Send
82        + Sync,
83>;
84
85/// Type alias for the recovery event callback.
86pub type OnRecoveryEventFn = Box<dyn Fn(RecoveryEvent) + Send + Sync>;
87
88/// Options for manifest recovery.
89#[derive(Default)]
90pub struct RecoveryOptions {
91    /// Optional function to refetch patch data from external source (e.g., database).
92    /// Should return patch data or None if not found.
93    pub refetch_patch: Option<RefetchPatchFn>,
94
95    /// Optional callback for logging recovery events.
96    pub on_recovery_event: Option<OnRecoveryEventFn>,
97}
98
99
100/// Recover and validate manifest with automatic repair of invalid patches.
101///
102/// This function attempts to parse and validate a manifest. If the manifest
103/// contains invalid patches, it will attempt to recover them using the provided
104/// refetch function. Patches that cannot be recovered are discarded.
105pub async fn recover_manifest(
106    parsed: &serde_json::Value,
107    options: RecoveryOptions,
108) -> RecoveryResult {
109    let RecoveryOptions {
110        refetch_patch,
111        on_recovery_event,
112    } = options;
113
114    let emit = |event: RecoveryEvent| {
115        if let Some(ref cb) = on_recovery_event {
116            cb(event);
117        }
118    };
119
120    // Try strict parse first (fast path for valid manifests)
121    if let Ok(manifest) = serde_json::from_value::<PatchManifest>(parsed.clone()) {
122        return RecoveryResult {
123            manifest,
124            repair_needed: false,
125            invalid_patches: vec![],
126            recovered_patches: vec![],
127            discarded_patches: vec![],
128        };
129    }
130
131    // Extract patches object with safety checks
132    let patches_obj = parsed
133        .as_object()
134        .and_then(|obj| obj.get("patches"))
135        .and_then(|p| p.as_object());
136
137    let patches_obj = match patches_obj {
138        Some(obj) => obj,
139        None => {
140            // Completely corrupted manifest
141            emit(RecoveryEvent::CorruptedManifest);
142            return RecoveryResult {
143                manifest: PatchManifest::new(),
144                repair_needed: true,
145                invalid_patches: vec![],
146                recovered_patches: vec![],
147                discarded_patches: vec![],
148            };
149        }
150    };
151
152    // Try to recover individual patches
153    let mut recovered_patches_map: HashMap<String, PatchRecord> = HashMap::new();
154    let mut invalid_patches: Vec<String> = Vec::new();
155    let mut recovered_patches: Vec<String> = Vec::new();
156    let mut discarded_patches: Vec<String> = Vec::new();
157
158    for (purl, patch_data) in patches_obj {
159        // Try to parse this individual patch
160        if let Ok(record) = serde_json::from_value::<PatchRecord>(patch_data.clone()) {
161            // Valid patch, keep it as-is
162            recovered_patches_map.insert(purl.clone(), record);
163        } else {
164            // Invalid patch, try to recover from external source
165            let uuid = patch_data
166                .as_object()
167                .and_then(|obj| obj.get("uuid"))
168                .and_then(|v| v.as_str())
169                .map(|s| s.to_string());
170
171            invalid_patches.push(purl.clone());
172            emit(RecoveryEvent::InvalidPatch {
173                purl: purl.clone(),
174                uuid: uuid.clone(),
175            });
176
177            if let (Some(ref uuid_str), Some(ref refetch)) = (&uuid, &refetch_patch) {
178                // Try to refetch from external source
179                match refetch(uuid_str.clone(), Some(purl.clone())).await {
180                    Ok(Some(patch_from_source)) => {
181                        if patch_from_source.purl == *purl {
182                            // Successfully recovered, reconstruct patch record
183                            let mut manifest_files: HashMap<String, PatchFileInfo> =
184                                HashMap::new();
185                            for (file_path, file_info) in &patch_from_source.files {
186                                if let (Some(before), Some(after)) =
187                                    (&file_info.before_hash, &file_info.after_hash)
188                                {
189                                    manifest_files.insert(
190                                        file_path.clone(),
191                                        PatchFileInfo {
192                                            before_hash: before.clone(),
193                                            after_hash: after.clone(),
194                                        },
195                                    );
196                                }
197                            }
198
199                            let mut vulns: HashMap<String, VulnerabilityInfo> = HashMap::new();
200                            for (vuln_id, vuln_data) in &patch_from_source.vulnerabilities {
201                                vulns.insert(
202                                    vuln_id.clone(),
203                                    VulnerabilityInfo {
204                                        cves: vuln_data.cves.clone(),
205                                        summary: vuln_data.summary.clone(),
206                                        severity: vuln_data.severity.clone(),
207                                        description: vuln_data.description.clone(),
208                                    },
209                                );
210                            }
211
212                            recovered_patches_map.insert(
213                                purl.clone(),
214                                PatchRecord {
215                                    uuid: patch_from_source.uuid.clone(),
216                                    exported_at: patch_from_source.published_at.clone(),
217                                    files: manifest_files,
218                                    vulnerabilities: vulns,
219                                    description: patch_from_source.description.clone(),
220                                    license: patch_from_source.license.clone(),
221                                    tier: patch_from_source.tier.clone(),
222                                },
223                            );
224
225                            recovered_patches.push(purl.clone());
226                            emit(RecoveryEvent::RecoveredPatch {
227                                purl: purl.clone(),
228                                uuid: uuid_str.clone(),
229                            });
230                        } else {
231                            // PURL mismatch - wrong package!
232                            discarded_patches.push(purl.clone());
233                            emit(RecoveryEvent::DiscardedPatchPurlMismatch {
234                                purl: purl.clone(),
235                                uuid: uuid_str.clone(),
236                                db_purl: patch_from_source.purl.clone(),
237                            });
238                        }
239                    }
240                    Ok(None) => {
241                        // Not found in external source (might be unpublished)
242                        discarded_patches.push(purl.clone());
243                        emit(RecoveryEvent::DiscardedPatchNotFound {
244                            purl: purl.clone(),
245                            uuid: uuid_str.clone(),
246                        });
247                    }
248                    Err(error_msg) => {
249                        // Error during recovery
250                        discarded_patches.push(purl.clone());
251                        emit(RecoveryEvent::RecoveryError {
252                            purl: purl.clone(),
253                            uuid: uuid_str.clone(),
254                            error: error_msg,
255                        });
256                    }
257                }
258            } else {
259                // No UUID or no refetch function, can't recover
260                discarded_patches.push(purl.clone());
261                if let Some(uuid) = uuid {
262                    emit(RecoveryEvent::DiscardedPatchNotFound {
263                        purl: purl.clone(),
264                        uuid,
265                    });
266                } else {
267                    emit(RecoveryEvent::DiscardedPatchNoUuid {
268                        purl: purl.clone(),
269                    });
270                }
271            }
272        }
273    }
274
275    let repair_needed = !invalid_patches.is_empty();
276
277    RecoveryResult {
278        manifest: PatchManifest {
279            patches: recovered_patches_map,
280        },
281        repair_needed,
282        invalid_patches,
283        recovered_patches,
284        discarded_patches,
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    const UUID_1: &str = "11111111-1111-4111-8111-111111111111";
    const UUID_2: &str = "22222222-2222-4222-8222-222222222222";

    /// A patch entry that satisfies the strict `PatchRecord` schema.
    fn valid_record_json(uuid: &str, description: &str) -> serde_json::Value {
        json!({
            "uuid": uuid,
            "exportedAt": "2024-01-01T00:00:00Z",
            "files": {},
            "vulnerabilities": {},
            "description": description,
            "license": "MIT",
            "tier": "free"
        })
    }

    /// Patch data as a refetch callback would return it (no files or vulns).
    fn patch_data(purl: &str, description: &str) -> PatchData {
        PatchData {
            uuid: UUID_1.to_string(),
            purl: purl.to_string(),
            published_at: "2024-01-01T00:00:00Z".to_string(),
            files: HashMap::new(),
            vulnerabilities: HashMap::new(),
            description: description.to_string(),
            license: "MIT".to_string(),
            tier: "free".to_string(),
        }
    }

    /// A refetch callback that always resolves to a clone of `result`.
    fn refetch_returning(result: Result<Option<PatchData>, String>) -> RefetchPatchFn {
        Box::new(move |_uuid, _purl| {
            let result = result.clone();
            Box::pin(async move { result })
        })
    }

    /// A shared event log plus a callback that appends every event to it.
    fn event_recorder() -> (Arc<Mutex<Vec<RecoveryEvent>>>, OnRecoveryEventFn) {
        let log = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let callback: OnRecoveryEventFn = Box::new(move |event| {
            sink.lock().unwrap().push(event);
        });
        (log, callback)
    }

    #[tokio::test]
    async fn test_valid_manifest_no_repair() {
        let parsed = json!({
            "patches": { "pkg:npm/test@1.0.0": valid_record_json(UUID_1, "test") }
        });

        let result = recover_manifest(&parsed, RecoveryOptions::default()).await;
        assert!(!result.repair_needed);
        assert_eq!(result.manifest.patches.len(), 1);
        assert!(result.invalid_patches.is_empty());
        assert!(result.recovered_patches.is_empty());
        assert!(result.discarded_patches.is_empty());
    }

    #[tokio::test]
    async fn test_corrupted_manifest_no_patches_key() {
        let (log, on_event) = event_recorder();
        let options = RecoveryOptions {
            refetch_patch: None,
            on_recovery_event: Some(on_event),
        };

        let result = recover_manifest(&json!({ "something": "else" }), options).await;
        assert!(result.repair_needed);
        assert_eq!(result.manifest.patches.len(), 0);
        assert!(matches!(
            log.lock().unwrap().as_slice(),
            [RecoveryEvent::CorruptedManifest]
        ));
    }

    #[tokio::test]
    async fn test_corrupted_manifest_patches_not_object() {
        let parsed = json!({ "patches": "not-an-object" });

        let result = recover_manifest(&parsed, RecoveryOptions::default()).await;
        assert!(result.repair_needed);
        assert_eq!(result.manifest.patches.len(), 0);
    }

    #[tokio::test]
    async fn test_invalid_patch_discarded_no_refetch() {
        // Entry is missing every required field except `uuid`.
        let parsed = json!({ "patches": { "pkg:npm/test@1.0.0": { "uuid": UUID_1 } } });
        let (log, on_event) = event_recorder();
        let options = RecoveryOptions {
            refetch_patch: None,
            on_recovery_event: Some(on_event),
        };

        let result = recover_manifest(&parsed, options).await;
        assert!(result.repair_needed);
        assert_eq!(result.manifest.patches.len(), 0);
        assert_eq!(result.invalid_patches.len(), 1);
        assert_eq!(result.discarded_patches.len(), 1);
        // Without a refetch callback, an entry with a UUID is reported as not found.
        assert!(matches!(
            log.lock().unwrap().as_slice(),
            [
                RecoveryEvent::InvalidPatch { uuid: Some(_), .. },
                RecoveryEvent::DiscardedPatchNotFound { .. },
            ]
        ));
    }

    #[tokio::test]
    async fn test_invalid_patch_no_uuid_discarded() {
        let parsed = json!({ "patches": { "pkg:npm/test@1.0.0": { "garbage": true } } });
        let (log, on_event) = event_recorder();
        let options = RecoveryOptions {
            refetch_patch: None,
            on_recovery_event: Some(on_event),
        };

        let result = recover_manifest(&parsed, options).await;
        assert!(result.repair_needed);
        assert_eq!(result.discarded_patches.len(), 1);
        assert!(matches!(
            log.lock().unwrap().as_slice(),
            [
                RecoveryEvent::InvalidPatch { uuid: None, .. },
                RecoveryEvent::DiscardedPatchNoUuid { .. },
            ]
        ));
    }

    #[tokio::test]
    async fn test_non_string_uuid_treated_as_missing() {
        let parsed = json!({ "patches": { "pkg:npm/test@1.0.0": { "uuid": 42 } } });
        let (log, on_event) = event_recorder();
        // Even with a refetch callback there is no usable UUID to look up.
        let options = RecoveryOptions {
            refetch_patch: Some(refetch_returning(Ok(None))),
            on_recovery_event: Some(on_event),
        };

        let result = recover_manifest(&parsed, options).await;
        assert_eq!(result.discarded_patches, vec!["pkg:npm/test@1.0.0".to_string()]);
        assert!(log
            .lock()
            .unwrap()
            .iter()
            .any(|e| matches!(e, RecoveryEvent::DiscardedPatchNoUuid { .. })));
    }

    #[tokio::test]
    async fn test_mix_valid_and_invalid_patches() {
        let parsed = json!({
            "patches": {
                "pkg:npm/good@1.0.0": valid_record_json(UUID_1, "good patch"),
                "pkg:npm/bad@1.0.0": { "uuid": UUID_2 }
            }
        });

        let result = recover_manifest(&parsed, RecoveryOptions::default()).await;
        assert!(result.repair_needed);
        assert_eq!(result.manifest.patches.len(), 1);
        assert!(result.manifest.patches.contains_key("pkg:npm/good@1.0.0"));
        assert_eq!(result.invalid_patches.len(), 1);
        assert_eq!(result.discarded_patches.len(), 1);
    }

    #[tokio::test]
    async fn test_recovery_with_refetch_success() {
        let parsed = json!({ "patches": { "pkg:npm/test@1.0.0": { "uuid": UUID_1 } } });

        let mut data = patch_data("pkg:npm/test@1.0.0", "recovered");
        data.files.insert(
            "package/index.js".to_string(),
            PatchDataFileInfo {
                before_hash: Some("aaa".to_string()),
                after_hash: Some("bbb".to_string()),
            },
        );

        let (log, on_event) = event_recorder();
        let options = RecoveryOptions {
            refetch_patch: Some(refetch_returning(Ok(Some(data)))),
            on_recovery_event: Some(on_event),
        };

        let result = recover_manifest(&parsed, options).await;
        assert!(result.repair_needed);
        assert_eq!(result.manifest.patches.len(), 1);
        assert_eq!(result.recovered_patches.len(), 1);
        assert_eq!(result.discarded_patches.len(), 0);

        let record = result.manifest.patches.get("pkg:npm/test@1.0.0").unwrap();
        assert_eq!(record.uuid, UUID_1);
        assert_eq!(record.exported_at, "2024-01-01T00:00:00Z");
        assert_eq!(record.description, "recovered");
        assert_eq!(record.files.len(), 1);
        assert!(log
            .lock()
            .unwrap()
            .iter()
            .any(|e| matches!(e, RecoveryEvent::RecoveredPatch { .. })));
    }

    #[tokio::test]
    async fn test_recovery_skips_files_missing_a_hash() {
        let parsed = json!({ "patches": { "pkg:npm/test@1.0.0": { "uuid": UUID_1 } } });

        let mut data = patch_data("pkg:npm/test@1.0.0", "recovered");
        data.files.insert(
            "package/complete.js".to_string(),
            PatchDataFileInfo {
                before_hash: Some("aaa".to_string()),
                after_hash: Some("bbb".to_string()),
            },
        );
        data.files.insert(
            "package/partial.js".to_string(),
            PatchDataFileInfo {
                before_hash: Some("aaa".to_string()),
                after_hash: None,
            },
        );

        let options = RecoveryOptions {
            refetch_patch: Some(refetch_returning(Ok(Some(data)))),
            on_recovery_event: None,
        };

        let result = recover_manifest(&parsed, options).await;
        let record = result.manifest.patches.get("pkg:npm/test@1.0.0").unwrap();
        assert_eq!(record.files.len(), 1);
        assert!(record.files.contains_key("package/complete.js"));
    }

    #[tokio::test]
    async fn test_refetch_receives_uuid_and_purl() {
        let parsed = json!({ "patches": { "pkg:npm/test@1.0.0": { "uuid": UUID_1 } } });

        let calls: Arc<Mutex<Vec<(String, Option<String>)>>> = Arc::new(Mutex::new(Vec::new()));
        let calls_ref = Arc::clone(&calls);
        let options = RecoveryOptions {
            refetch_patch: Some(Box::new(move |uuid, purl| {
                calls_ref.lock().unwrap().push((uuid, purl));
                Box::pin(async { Ok(None) })
            })),
            on_recovery_event: None,
        };

        recover_manifest(&parsed, options).await;
        assert_eq!(
            *calls.lock().unwrap(),
            vec![(UUID_1.to_string(), Some("pkg:npm/test@1.0.0".to_string()))]
        );
    }

    #[tokio::test]
    async fn test_recovery_with_purl_mismatch() {
        let parsed = json!({ "patches": { "pkg:npm/test@1.0.0": { "uuid": UUID_1 } } });
        let (log, on_event) = event_recorder();
        let options = RecoveryOptions {
            // The source returns a patch for a different package.
            refetch_patch: Some(refetch_returning(Ok(Some(patch_data(
                "pkg:npm/other@2.0.0",
                "wrong",
            ))))),
            on_recovery_event: Some(on_event),
        };

        let result = recover_manifest(&parsed, options).await;
        assert!(result.repair_needed);
        assert_eq!(result.manifest.patches.len(), 0);
        assert_eq!(result.discarded_patches.len(), 1);
        assert!(log.lock().unwrap().iter().any(|e| matches!(
            e,
            RecoveryEvent::DiscardedPatchPurlMismatch { db_purl, .. }
                if db_purl == "pkg:npm/other@2.0.0"
        )));
    }

    #[tokio::test]
    async fn test_recovery_with_refetch_not_found() {
        let parsed = json!({ "patches": { "pkg:npm/test@1.0.0": { "uuid": UUID_1 } } });
        let (log, on_event) = event_recorder();
        let options = RecoveryOptions {
            refetch_patch: Some(refetch_returning(Ok(None))),
            on_recovery_event: Some(on_event),
        };

        let result = recover_manifest(&parsed, options).await;
        assert!(result.repair_needed);
        assert_eq!(result.manifest.patches.len(), 0);
        assert_eq!(result.discarded_patches.len(), 1);
        assert!(log
            .lock()
            .unwrap()
            .iter()
            .any(|e| matches!(e, RecoveryEvent::DiscardedPatchNotFound { .. })));
    }

    #[tokio::test]
    async fn test_recovery_with_refetch_error() {
        let parsed = json!({ "patches": { "pkg:npm/test@1.0.0": { "uuid": UUID_1 } } });
        let (log, on_event) = event_recorder();
        let options = RecoveryOptions {
            refetch_patch: Some(refetch_returning(Err("network error".to_string()))),
            on_recovery_event: Some(on_event),
        };

        let result = recover_manifest(&parsed, options).await;
        assert!(result.repair_needed);
        assert_eq!(result.manifest.patches.len(), 0);
        assert_eq!(result.discarded_patches.len(), 1);
        assert!(log.lock().unwrap().iter().any(|e| matches!(
            e,
            RecoveryEvent::RecoveryError { error, .. } if error == "network error"
        )));
    }
}