Skip to main content

socket_patch_cli/commands/
rollback.rs

1use clap::Args;
2use socket_patch_core::api::blob_fetcher::{
3    fetch_blobs_by_hash, format_fetch_result,
4};
5use socket_patch_core::api::client::get_api_client_with_overrides;
6use socket_patch_core::crawlers::CrawlerOptions;
7use socket_patch_core::manifest::operations::read_manifest;
8use socket_patch_core::manifest::schema::{PatchFileInfo, PatchManifest, PatchRecord};
9use socket_patch_core::patch::apply::select_installed_variants;
10use socket_patch_core::patch::rollback::{rollback_package_patch, RollbackResult, VerifyRollbackStatus};
11use socket_patch_core::utils::purl::{purl_matches_identifier, strip_purl_qualifiers};
12use socket_patch_core::utils::telemetry::{track_patch_rolled_back, track_patch_rollback_failed};
13use std::collections::{HashMap, HashSet};
14use std::path::{Path, PathBuf};
15use std::time::Duration;
16
17use crate::args::{apply_env_toggles, GlobalArgs};
18use crate::commands::lock_cli::{acquire_or_emit, LOCK_BROKEN_CODE};
19use crate::ecosystem_dispatch::{find_packages_for_rollback, partition_purls};
20use crate::json_envelope::Command as EnvelopeCommand;
21
// Command-line arguments accepted by the rollback command (parsed by clap).
// The `///` comments on the fields below double as user-facing help text,
// so they are intentionally left untouched.
#[derive(Args)]
pub struct RollbackArgs {
    /// Package PURL or patch UUID to rollback. Omit to rollback all patches.
    pub identifier: Option<String>,

    // Options shared with other commands (JSON/silent output, dry-run,
    // offline mode, lock handling, ...), flattened into this command's flags.
    #[command(flatten)]
    pub common: GlobalArgs,

    /// Rollback a patch by fetching beforeHash blobs from API (no manifest required).
    // NOTE(review): `run` currently exits with code 1 whenever this flag is
    // set; the mode has no implementation in this file.
    #[arg(long = "one-off", env = "SOCKET_ONE_OFF", default_value_t = false)]
    pub one_off: bool,
}
34
/// A manifest entry selected for rollback: the key it is stored under in
/// `PatchManifest::patches` together with an owned copy of its record.
struct PatchToRollback {
    /// Manifest key (package URL) of the patch.
    purl: String,
    /// Clone of the patch record stored under that key.
    patch: PatchRecord,
}
39
40fn find_patches_to_rollback(
41    manifest: &PatchManifest,
42    identifier: Option<&str>,
43) -> Vec<PatchToRollback> {
44    match identifier {
45        None => manifest
46            .patches
47            .iter()
48            .map(|(purl, patch)| PatchToRollback {
49                purl: purl.clone(),
50                patch: patch.clone(),
51            })
52            .collect(),
53        Some(id) => {
54            let mut patches = Vec::new();
55            if id.starts_with("pkg:") {
56                // A base PURL (no `?`) matches every release variant of
57                // that package@version; a qualified PURL targets one.
58                for (purl, patch) in &manifest.patches {
59                    if purl_matches_identifier(purl, id) {
60                        patches.push(PatchToRollback {
61                            purl: purl.clone(),
62                            patch: patch.clone(),
63                        });
64                    }
65                }
66            } else {
67                for (purl, patch) in &manifest.patches {
68                    if patch.uuid == id {
69                        patches.push(PatchToRollback {
70                            purl: purl.clone(),
71                            patch: patch.clone(),
72                        });
73                    }
74                }
75            }
76            patches
77        }
78    }
79}
80
81fn get_before_hash_blobs(manifest: &PatchManifest) -> HashSet<String> {
82    let mut blobs = HashSet::new();
83    for patch in manifest.patches.values() {
84        for file_info in patch.files.values() {
85            blobs.insert(file_info.before_hash.clone());
86        }
87    }
88    blobs
89}
90
91async fn get_missing_before_blobs(
92    manifest: &PatchManifest,
93    blobs_path: &Path,
94) -> HashSet<String> {
95    let before_blobs = get_before_hash_blobs(manifest);
96    let mut missing = HashSet::new();
97    for hash in before_blobs {
98        let blob_path = blobs_path.join(&hash);
99        if tokio::fs::metadata(&blob_path).await.is_err() {
100            missing.insert(hash);
101        }
102    }
103    missing
104}
105
106fn verify_rollback_status_str(status: &VerifyRollbackStatus) -> &'static str {
107    match status {
108        VerifyRollbackStatus::Ready => "ready",
109        VerifyRollbackStatus::AlreadyOriginal => "already_original",
110        VerifyRollbackStatus::HashMismatch => "hash_mismatch",
111        VerifyRollbackStatus::NotFound => "not_found",
112        VerifyRollbackStatus::MissingBlob => "missing_blob",
113    }
114}
115
116/// True when every file the engine verified for this package is already
117/// at its original (`beforeHash`) state — i.e. the rollback is a complete
118/// no-op on disk.
119///
120/// This is the rollback-side mirror of apply's `all_files_already_patched`.
121/// The `!is_empty()` guard is essential: `Iterator::all` over an empty
122/// slice is vacuously `true`. Without it a result with no verified files
123/// — a zero-file patch record, or a result whose `files_verified` came
124/// back empty — would be mislabeled "already original" and miscounted as
125/// a no-op even though nothing matched `beforeHash`.
126fn all_files_already_original(result: &RollbackResult) -> bool {
127    !result.files_verified.is_empty()
128        && result
129            .files_verified
130            .iter()
131            .all(|f| f.status == VerifyRollbackStatus::AlreadyOriginal)
132}
133
134/// Number of packages that have files which actually need restoring,
135/// used by the dry-run summary. Successful-but-already-original packages
136/// are no-ops reported on their own line, so they are excluded here —
137/// mirroring apply's dry-run split — to avoid double-counting them
138/// against "can be rolled back".
139fn can_rollback_count(results: &[RollbackResult]) -> usize {
140    let successful = results.iter().filter(|r| r.success).count();
141    let already_original = results
142        .iter()
143        .filter(|r| r.success && all_files_already_original(r))
144        .count();
145    successful.saturating_sub(already_original)
146}
147
148fn result_to_json(result: &RollbackResult) -> serde_json::Value {
149    serde_json::json!({
150        "purl": result.package_key,
151        "path": result.package_path,
152        "success": result.success,
153        "error": result.error,
154        "filesRolledBack": result.files_rolled_back,
155        "filesVerified": result.files_verified.iter().map(|f| {
156            serde_json::json!({
157                "file": f.file,
158                "status": verify_rollback_status_str(&f.status),
159                "message": f.message,
160                "currentHash": f.current_hash,
161                "expectedHash": f.expected_hash,
162                "targetHash": f.target_hash,
163            })
164        }).collect::<Vec<_>>(),
165    })
166}
167
168pub async fn run(args: RollbackArgs) -> i32 {
169    apply_env_toggles(&args.common);
170
171    let (telemetry_client, _) =
172        get_api_client_with_overrides(args.common.api_client_overrides()).await;
173    let api_token = telemetry_client.api_token().cloned();
174    let org_slug = telemetry_client.org_slug().cloned();
175
176    // Validate one-off requires identifier
177    if args.one_off && args.identifier.is_none() {
178        if args.common.json {
179            println!("{}", serde_json::to_string_pretty(&serde_json::json!({
180                "status": "error",
181                "error": "--one-off requires an identifier (UUID or PURL)",
182            })).unwrap());
183        } else {
184            eprintln!("Error: --one-off requires an identifier (UUID or PURL)");
185        }
186        return 1;
187    }
188
189    // Handle one-off mode
190    if args.one_off {
191        if args.common.json {
192            println!("{}", serde_json::to_string_pretty(&serde_json::json!({
193                "status": "error",
194                "error": "One-off rollback mode is not yet implemented",
195            })).unwrap());
196        } else {
197            eprintln!("One-off rollback mode: fetching patch data...");
198        }
199        return 1;
200    }
201
202    let manifest_path = args.common.resolved_manifest_path();
203
204    if tokio::fs::metadata(&manifest_path).await.is_err() {
205        if args.common.json {
206            println!("{}", serde_json::to_string_pretty(&serde_json::json!({
207                "status": "error",
208                "error": "Manifest not found",
209                "path": manifest_path.display().to_string(),
210            })).unwrap());
211        } else if !args.common.silent {
212            eprintln!("Manifest not found at {}", manifest_path.display());
213        }
214        return 1;
215    }
216
217    // Serialize against concurrent socket-patch runs targeting the
218    // same `.socket/` directory. See
219    // `socket_patch_core::patch::apply_lock`.
220    let socket_dir = manifest_path.parent().unwrap_or(Path::new("."));
221    let acquired = match acquire_or_emit(
222        socket_dir,
223        EnvelopeCommand::Rollback,
224        args.common.json,
225        args.common.silent,
226        args.common.dry_run,
227        Duration::from_secs(args.common.lock_timeout.unwrap_or(0)),
228        args.common.break_lock,
229    ) {
230        Ok(acquired) => acquired,
231        Err(code) => return code,
232    };
233    let _lock = acquired.guard;
234    let lock_was_broken = acquired.broke_lock;
235
236    match rollback_patches_inner(&args, &manifest_path).await {
237        Ok((success, results)) => {
238            let rolled_back_count = results
239                .iter()
240                .filter(|r| r.success && !r.files_rolled_back.is_empty())
241                .count();
242            let already_original_count = results
243                .iter()
244                .filter(|r| r.success && all_files_already_original(r))
245                .count();
246            let failed_count = results.iter().filter(|r| !r.success).count();
247
248            if args.common.json {
249                // `warnings` carries non-fatal audit info — currently
250                // just the `lock_broken` notice when --break-lock fired.
251                // Empty array stays present in the JSON shape so
252                // consumers can rely on `.warnings[]` without
253                // null-checking.
254                let mut warnings = Vec::new();
255                if lock_was_broken {
256                    warnings.push(serde_json::json!({
257                        "code": LOCK_BROKEN_CODE,
258                        "message": format!(
259                            "--break-lock removed {}/apply.lock before acquisition",
260                            socket_dir.display()
261                        ),
262                    }));
263                }
264                println!("{}", serde_json::to_string_pretty(&serde_json::json!({
265                    "status": if success { "success" } else { "partial_failure" },
266                    "rolledBack": rolled_back_count,
267                    "alreadyOriginal": already_original_count,
268                    "failed": failed_count,
269                    "dryRun": args.common.dry_run,
270                    "warnings": warnings,
271                    "results": results.iter().map(result_to_json).collect::<Vec<_>>(),
272                })).unwrap());
273            } else if !args.common.silent && !results.is_empty() {
274                let rolled_back: Vec<_> = results
275                    .iter()
276                    .filter(|r| r.success && !r.files_rolled_back.is_empty())
277                    .collect();
278                let already_original: Vec<_> = results
279                    .iter()
280                    .filter(|r| r.success && all_files_already_original(r))
281                    .collect();
282                let failed: Vec<_> = results.iter().filter(|r| !r.success).collect();
283
284                if args.common.dry_run {
285                    println!("\nRollback verification complete:");
286                    // Exclude already-original packages — they are
287                    // reported separately just below, so counting them
288                    // here too would double-report each no-op.
289                    let can_rollback = can_rollback_count(&results);
290                    println!("  {can_rollback} package(s) can be rolled back");
291                    if !already_original.is_empty() {
292                        println!(
293                            "  {} package(s) already in original state",
294                            already_original.len()
295                        );
296                    }
297                    if !failed.is_empty() {
298                        println!("  {} package(s) cannot be rolled back", failed.len());
299                    }
300                } else {
301                    if !rolled_back.is_empty() || !already_original.is_empty() {
302                        println!("\nRolled back packages:");
303                        for result in &rolled_back {
304                            println!("  {}", result.package_key);
305                        }
306                        for result in &already_original {
307                            println!("  {} (already original)", result.package_key);
308                        }
309                    }
310                    if !failed.is_empty() {
311                        println!("\nFailed to rollback:");
312                        for result in &failed {
313                            println!(
314                                "  {}: {}",
315                                result.package_key,
316                                result.error.as_deref().unwrap_or("unknown error")
317                            );
318                        }
319                    }
320                }
321
322                if args.common.verbose {
323                    println!("\nDetailed verification:");
324                    for result in &results {
325                        println!("  {}:", result.package_key);
326                        for f in &result.files_verified {
327                            let status_str = match f.status {
328                                VerifyRollbackStatus::Ready => "ready",
329                                VerifyRollbackStatus::AlreadyOriginal => "already original",
330                                VerifyRollbackStatus::HashMismatch => "hash mismatch",
331                                VerifyRollbackStatus::NotFound => "not found",
332                                VerifyRollbackStatus::MissingBlob => "missing blob",
333                            };
334                            println!("    {} [{}]", f.file, status_str);
335                            if let Some(ref msg) = f.message {
336                                println!("      message: {msg}");
337                            }
338                            if let Some(ref h) = f.current_hash {
339                                println!("      current:  {h}");
340                            }
341                            if let Some(ref h) = f.expected_hash {
342                                println!("      expected: {h}");
343                            }
344                            if let Some(ref h) = f.target_hash {
345                                println!("      target:   {h}");
346                            }
347                        }
348                    }
349                }
350            }
351
352            if success {
353                track_patch_rolled_back(rolled_back_count, api_token.as_deref(), org_slug.as_deref()).await;
354            } else {
355                track_patch_rollback_failed("One or more rollbacks failed", api_token.as_deref(), org_slug.as_deref()).await;
356            }
357
358            if success { 0 } else { 1 }
359        }
360        Err(e) => {
361            track_patch_rollback_failed(&e, api_token.as_deref(), org_slug.as_deref()).await;
362            if args.common.json {
363                println!("{}", serde_json::to_string_pretty(&serde_json::json!({
364                    "status": "error",
365                    "error": e,
366                    "rolledBack": 0,
367                    "alreadyOriginal": 0,
368                    "failed": 0,
369                    "dryRun": args.common.dry_run,
370                    "results": [],
371                })).unwrap());
372            } else if !args.common.silent {
373                eprintln!("Error: {e}");
374            }
375            1
376        }
377    }
378}
379
380async fn rollback_patches_inner(
381    args: &RollbackArgs,
382    manifest_path: &Path,
383) -> Result<(bool, Vec<RollbackResult>), String> {
384    let manifest = read_manifest(manifest_path)
385        .await
386        .map_err(|e| e.to_string())?
387        .ok_or_else(|| "Invalid manifest".to_string())?;
388
389    let socket_dir = manifest_path.parent().unwrap();
390    let blobs_path = socket_dir.join("blobs");
391    tokio::fs::create_dir_all(&blobs_path)
392        .await
393        .map_err(|e| e.to_string())?;
394
395    let patches_to_rollback =
396        find_patches_to_rollback(&manifest, args.identifier.as_deref());
397
398    if patches_to_rollback.is_empty() {
399        if args.identifier.is_some() {
400            return Err(format!(
401                "No patch found matching identifier: {}",
402                args.identifier.as_deref().unwrap()
403            ));
404        }
405        if !args.common.silent && !args.common.json {
406            println!("No patches found in manifest");
407        }
408        return Ok((true, Vec::new()));
409    }
410
411    // Create filtered manifest
412    let filtered_manifest = PatchManifest {
413        patches: patches_to_rollback
414            .iter()
415            .map(|p| (p.purl.clone(), p.patch.clone()))
416            .collect(),
417    };
418
419    // Check for missing beforeHash blobs
420    let missing_blobs = get_missing_before_blobs(&filtered_manifest, &blobs_path).await;
421    if !missing_blobs.is_empty() {
422        if args.common.offline {
423            if !args.common.silent && !args.common.json {
424                eprintln!(
425                    "Error: {} blob(s) are missing and --offline mode is enabled.",
426                    missing_blobs.len()
427                );
428                eprintln!("Run \"socket-patch repair\" to download missing blobs.");
429            }
430            return Ok((false, Vec::new()));
431        }
432
433        if !args.common.silent && !args.common.json {
434            println!("Downloading {} missing blob(s)...", missing_blobs.len());
435        }
436
437        let (client, _) =
438            get_api_client_with_overrides(args.common.api_client_overrides()).await;
439        let fetch_result = fetch_blobs_by_hash(&missing_blobs, &blobs_path, &client, None).await;
440
441        if !args.common.silent && !args.common.json {
442            println!("{}", format_fetch_result(&fetch_result));
443        }
444
445        let still_missing = get_missing_before_blobs(&filtered_manifest, &blobs_path).await;
446        if !still_missing.is_empty() {
447            if !args.common.silent && !args.common.json {
448                eprintln!(
449                    "{} blob(s) could not be downloaded. Cannot rollback.",
450                    still_missing.len()
451                );
452            }
453            return Ok((false, Vec::new()));
454        }
455    }
456
457    // Partition PURLs by ecosystem
458    let rollback_purls: Vec<String> = patches_to_rollback.iter().map(|p| p.purl.clone()).collect();
459    let partitioned =
460        partition_purls(&rollback_purls, args.common.ecosystems.as_deref());
461
462    let crawler_options = CrawlerOptions {
463        cwd: args.common.cwd.clone(),
464        global: args.common.global,
465        global_prefix: args.common.global_prefix.clone(),
466        batch_size: 100,
467    };
468
469    let all_packages =
470        find_packages_for_rollback(&partitioned, &crawler_options, args.common.silent || args.common.json).await;
471
472    if all_packages.is_empty() {
473        if !args.common.silent && !args.common.json {
474            println!("No packages found that match patches to rollback");
475        }
476        return Ok((true, Vec::new()));
477    }
478
479    // Group discovered packages by base PURL. A release-variant
480    // `package@version` (PyPI/RubyGems/Maven) may have several variants
481    // in the manifest that `merge_qualified` resolves to the same
482    // installed package dir. Rolling back a variant that is *not* present
483    // on disk would HashMismatch and report a spurious failure, so —
484    // mirroring apply — we collapse each group to the variant(s) whose
485    // hashes actually match the installed bytes. PyPI/RubyGems yield one
486    // such variant; Maven's coexisting classifier jars may yield several.
487    let mut groups: HashMap<String, Vec<(&String, &PathBuf)>> = HashMap::new();
488    for (purl, pkg_path) in &all_packages {
489        groups
490            .entry(strip_purl_qualifiers(purl).to_string())
491            .or_default()
492            .push((purl, pkg_path));
493    }
494
495    // Rollback patches
496    let mut results: Vec<RollbackResult> = Vec::new();
497    let mut has_errors = false;
498
499    for (_base, entries) in groups {
500        // Resolve which variant(s) to roll back for this base PURL.
501        let to_rollback: Vec<(&String, &PathBuf)> = if entries.len() == 1 {
502            entries
503        } else {
504            // All variants in a group resolve to the same installed path.
505            let pkg_path = entries[0].1;
506            let candidates: Vec<(&str, &HashMap<String, PatchFileInfo>)> = entries
507                .iter()
508                .filter_map(|(purl, _)| {
509                    filtered_manifest
510                        .patches
511                        .get(*purl)
512                        .map(|p| (purl.as_str(), &p.files))
513                })
514                .collect();
515            let matched = select_installed_variants(pkg_path, &candidates).await;
516            if matched.is_empty() {
517                // No variant matches the installed distribution (e.g. a
518                // locally-modified file). Fall back to attempting every
519                // variant so the per-file verification surfaces the
520                // mismatch rather than silently skipping the package.
521                entries
522            } else {
523                let winners: HashSet<String> =
524                    matched.iter().map(|&i| candidates[i].0.to_string()).collect();
525                entries
526                    .into_iter()
527                    .filter(|(p, _)| winners.contains(*p))
528                    .collect()
529            }
530        };
531
532        for (purl, pkg_path) in to_rollback {
533            let patch = match filtered_manifest.patches.get(purl) {
534                Some(p) => p,
535                None => continue,
536            };
537
538            let result = rollback_package_patch(
539                purl,
540                pkg_path,
541                &patch.files,
542                &blobs_path,
543                args.common.dry_run,
544            )
545            .await;
546
547            if !result.success {
548                has_errors = true;
549                if !args.common.silent && !args.common.json {
550                    eprintln!(
551                        "Failed to rollback {}: {}",
552                        purl,
553                        result.error.as_deref().unwrap_or("unknown error")
554                    );
555                }
556            }
557            results.push(result);
558        }
559    }
560
561    Ok((!has_errors, results))
562}
563
564// Export for use by remove command
565#[allow(clippy::too_many_arguments)]
566pub async fn rollback_patches(
567    cwd: &Path,
568    manifest_path: &Path,
569    identifier: Option<&str>,
570    dry_run: bool,
571    silent: bool,
572    offline: bool,
573    global: bool,
574    global_prefix: Option<PathBuf>,
575    ecosystems: Option<Vec<String>>,
576) -> Result<(bool, Vec<RollbackResult>), String> {
577    let args = RollbackArgs {
578        identifier: identifier.map(String::from),
579        common: crate::args::GlobalArgs {
580            cwd: cwd.to_path_buf(),
581            manifest_path: manifest_path.display().to_string(),
582            offline,
583            global,
584            global_prefix,
585            ecosystems,
586            silent,
587            dry_run,
588            ..crate::args::GlobalArgs::default()
589        },
590        one_off: false,
591    };
592    rollback_patches_inner(&args, manifest_path).await
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use socket_patch_core::manifest::schema::{PatchManifest, PatchRecord};
    use socket_patch_core::patch::rollback::VerifyRollbackResult;
    use std::collections::HashMap;

    /// Patch record distinguished only by its UUID; it lists no files.
    fn make_record(uuid: &str) -> PatchRecord {
        PatchRecord {
            uuid: uuid.to_owned(),
            exported_at: "2024-01-01T00:00:00Z".to_owned(),
            files: HashMap::new(),
            vulnerabilities: HashMap::new(),
            description: "test patch".to_owned(),
            license: "MIT".to_owned(),
            tier: "free".to_owned(),
        }
    }

    /// Builds a manifest from `(purl, uuid)` pairs.
    fn manifest_from(entries: &[(&str, &str)]) -> PatchManifest {
        let patches = entries
            .iter()
            .map(|&(purl, uuid)| (purl.to_owned(), make_record(uuid)))
            .collect();
        PatchManifest { patches }
    }

    /// Three unrelated, unqualified packages.
    fn make_manifest() -> PatchManifest {
        manifest_from(&[
            ("pkg:npm/foo@1.0", "uuid-foo"),
            ("pkg:npm/bar@2.0", "uuid-bar"),
            ("pkg:pypi/baz@3.0", "uuid-baz"),
        ])
    }

    #[test]
    fn test_find_patches_to_rollback_none_returns_all() {
        let selected = find_patches_to_rollback(&make_manifest(), None);
        assert_eq!(selected.len(), 3);
    }

    #[test]
    fn test_find_patches_to_rollback_purl_match() {
        let selected =
            find_patches_to_rollback(&make_manifest(), Some("pkg:npm/foo@1.0"));
        assert_eq!(selected.len(), 1);
        assert_eq!(selected[0].purl, "pkg:npm/foo@1.0");
    }

    #[test]
    fn test_find_patches_to_rollback_purl_no_match() {
        let selected =
            find_patches_to_rollback(&make_manifest(), Some("pkg:npm/nonexistent@1"));
        assert!(selected.is_empty());
    }

    #[test]
    fn test_find_patches_to_rollback_uuid_match() {
        let selected = find_patches_to_rollback(&make_manifest(), Some("uuid-bar"));
        assert_eq!(selected.len(), 1);
        assert_eq!(selected[0].patch.uuid, "uuid-bar");
        assert_eq!(selected[0].purl, "pkg:npm/bar@2.0");
    }

    #[test]
    fn test_find_patches_to_rollback_uuid_no_match() {
        let selected =
            find_patches_to_rollback(&make_manifest(), Some("uuid-does-not-exist"));
        assert!(selected.is_empty());
    }

    /// A manifest holding several PyPI release variants of one
    /// package@version (broad mode), plus one unrelated npm package.
    fn make_multi_variant_manifest() -> PatchManifest {
        manifest_from(&[
            (
                "pkg:pypi/six@1.16.0?artifact_id=wheel-cp311",
                "uuid-wheel-cp311",
            ),
            (
                "pkg:pypi/six@1.16.0?artifact_id=wheel-cp312",
                "uuid-wheel-cp312",
            ),
            ("pkg:pypi/six@1.16.0?artifact_id=sdist", "uuid-sdist"),
            ("pkg:npm/foo@1.0", "uuid-foo"),
        ])
    }

    #[test]
    fn test_find_patches_to_rollback_base_purl_matches_all_variants() {
        let selected = find_patches_to_rollback(
            &make_multi_variant_manifest(),
            Some("pkg:pypi/six@1.16.0"),
        );
        // Base PURL (no qualifier) expands to every release variant.
        assert_eq!(selected.len(), 3);
        for entry in &selected {
            assert!(entry.purl.starts_with("pkg:pypi/six@1.16.0?artifact_id="));
        }
    }

    #[test]
    fn test_find_patches_to_rollback_qualified_purl_matches_one_variant() {
        let selected = find_patches_to_rollback(
            &make_multi_variant_manifest(),
            Some("pkg:pypi/six@1.16.0?artifact_id=sdist"),
        );
        // A fully-qualified PURL targets exactly one variant.
        assert_eq!(selected.len(), 1);
        assert_eq!(selected[0].purl, "pkg:pypi/six@1.16.0?artifact_id=sdist");
    }

    #[test]
    fn test_find_patches_to_rollback_base_purl_does_not_leak_other_packages() {
        let selected = find_patches_to_rollback(
            &make_multi_variant_manifest(),
            Some("pkg:pypi/six@1.16.0"),
        );
        assert!(selected.iter().all(|entry| entry.purl.contains("six@1.16.0")));
    }

    // --- Summary-counting regressions -----------------------------------
    //
    // These pin the rollback summary to the same contract apply uses:
    // an "already original" result must have at least one verified file,
    // and the dry-run "can be rolled back" count must not double-report
    // packages that are already in their original state.

    /// Verification entry for a fixed file name with the given status and
    /// no message or hashes.
    fn verified(status: VerifyRollbackStatus) -> VerifyRollbackResult {
        VerifyRollbackResult {
            file: "package/index.js".to_owned(),
            status,
            message: None,
            current_hash: None,
            expected_hash: None,
            target_hash: None,
        }
    }

    /// Build a `RollbackResult` from verification statuses and the list of
    /// files reported rolled back. `success` is set to whether every
    /// verified file is Ready/AlreadyOriginal, matching the engine.
    fn make_result(
        verified_statuses: &[VerifyRollbackStatus],
        rolled_back: &[&str],
    ) -> RollbackResult {
        let files_verified: Vec<VerifyRollbackResult> =
            verified_statuses.iter().cloned().map(verified).collect();
        let success = files_verified.iter().all(|entry| {
            matches!(
                entry.status,
                VerifyRollbackStatus::Ready | VerifyRollbackStatus::AlreadyOriginal
            )
        });
        let files_rolled_back = rolled_back.iter().map(|name| name.to_string()).collect();
        RollbackResult {
            package_key: "pkg:npm/foo@1.0.0".to_owned(),
            package_path: "/tmp/foo".to_owned(),
            success,
            files_verified,
            files_rolled_back,
            error: None,
        }
    }

    #[test]
    fn all_files_already_original_true_when_every_file_matches() {
        let outcome = make_result(
            &[
                VerifyRollbackStatus::AlreadyOriginal,
                VerifyRollbackStatus::AlreadyOriginal,
            ],
            &[],
        );
        assert!(all_files_already_original(&outcome));
    }

    #[test]
    fn all_files_already_original_false_when_any_file_differs() {
        let outcome = make_result(
            &[
                VerifyRollbackStatus::AlreadyOriginal,
                VerifyRollbackStatus::Ready,
            ],
            &[],
        );
        assert!(!all_files_already_original(&outcome));
    }

    /// Regression: `Iterator::all` over an empty slice is vacuously true.
    /// A successful result with no verified files (a zero-file patch
    /// record) must NOT be reported as "already original" — the
    /// `!is_empty()` guard enforces this, matching apply.
    #[test]
    fn all_files_already_original_false_when_no_verified_files() {
        let outcome = make_result(&[], &[]);
        assert!(outcome.files_verified.is_empty());
        assert!(outcome.success);
        assert!(!all_files_already_original(&outcome));
    }

    /// Regression: the dry-run "can be rolled back" count must exclude
    /// already-original packages, which are reported on their own line.
    /// Otherwise each no-op is double-counted (once as can-rollback, once
    /// as already-original).
    #[test]
    fn can_rollback_count_excludes_already_original() {
        // Genuinely needs restoring.
        let needs_restore = make_result(&[VerifyRollbackStatus::Ready], &[]);
        // No-op: already at beforeHash.
        let no_op = make_result(&[VerifyRollbackStatus::AlreadyOriginal], &[]);
        // Mixed → still needs restoring.
        let mixed = make_result(
            &[
                VerifyRollbackStatus::Ready,
                VerifyRollbackStatus::AlreadyOriginal,
            ],
            &[],
        );
        // Failed (e.g. HashMismatch) → not counted as rollbackable.
        let failed = make_result(&[VerifyRollbackStatus::HashMismatch], &[]);

        let results = vec![needs_restore, no_op, mixed, failed];
        // 2 successful non-no-op packages; the already-original one is
        // excluded and the failed one was never successful.
        assert_eq!(can_rollback_count(&results), 2);
    }

    /// A summary made entirely of no-ops reports zero rollbackable
    /// packages.
    #[test]
    fn can_rollback_count_all_already_original_is_zero() {
        let results = vec![
            make_result(&[VerifyRollbackStatus::AlreadyOriginal], &[]),
            make_result(&[VerifyRollbackStatus::AlreadyOriginal], &[]),
        ];
        assert_eq!(can_rollback_count(&results), 0);
    }
}