Skip to main content

socket_patch_cli/commands/
rollback.rs

1use clap::Args;
2use socket_patch_core::api::blob_fetcher::{
3    fetch_blobs_by_hash, format_fetch_result,
4};
5use socket_patch_core::api::client::get_api_client_with_overrides;
6use socket_patch_core::crawlers::CrawlerOptions;
7use socket_patch_core::manifest::operations::read_manifest;
8use socket_patch_core::manifest::schema::{PatchManifest, PatchRecord};
9use socket_patch_core::patch::rollback::{rollback_package_patch, RollbackResult, VerifyRollbackStatus};
10use socket_patch_core::utils::telemetry::{track_patch_rolled_back, track_patch_rollback_failed};
11use std::collections::HashSet;
12use std::path::{Path, PathBuf};
13use std::time::Duration;
14
15use crate::args::{apply_env_toggles, GlobalArgs};
16use crate::commands::lock_cli::{acquire_or_emit, LOCK_BROKEN_CODE};
17use crate::ecosystem_dispatch::{find_packages_for_rollback, partition_purls};
18use crate::json_envelope::Command as EnvelopeCommand;
19
// CLI arguments for `socket-patch rollback`.
//
// Plain `//` comments are used deliberately in this struct: clap turns `///`
// doc comments into `--help` text, so only the pre-existing `///` lines
// below are user-facing.
#[derive(Args)]
pub struct RollbackArgs {
    /// Package PURL or patch UUID to rollback. Omit to rollback all patches.
    // Identifiers starting with `pkg:` are looked up as manifest keys;
    // anything else is compared against each patch record's `uuid`
    // (see `find_patches_to_rollback`).
    pub identifier: Option<String>,

    // Flags shared with other commands (manifest path, --json, --silent,
    // --dry-run, --offline, lock timeout / --break-lock, ...).
    #[command(flatten)]
    pub common: GlobalArgs,

    /// Rollback a patch by fetching beforeHash blobs from API (no manifest required).
    // NOTE: `run` currently rejects this flag as not yet implemented.
    #[arg(long = "one-off", env = "SOCKET_ONE_OFF", default_value_t = false)]
    pub one_off: bool,
}
32
/// A manifest entry selected for rollback: the package PURL (the manifest
/// key) paired with an owned copy of its patch record.
struct PatchToRollback {
    /// Manifest key, e.g. `pkg:npm/foo@1.0`.
    purl: String,
    /// Clone of the manifest's record stored under `purl`.
    patch: PatchRecord,
}
37
38fn find_patches_to_rollback(
39    manifest: &PatchManifest,
40    identifier: Option<&str>,
41) -> Vec<PatchToRollback> {
42    match identifier {
43        None => manifest
44            .patches
45            .iter()
46            .map(|(purl, patch)| PatchToRollback {
47                purl: purl.clone(),
48                patch: patch.clone(),
49            })
50            .collect(),
51        Some(id) => {
52            let mut patches = Vec::new();
53            if id.starts_with("pkg:") {
54                if let Some(patch) = manifest.patches.get(id) {
55                    patches.push(PatchToRollback {
56                        purl: id.to_string(),
57                        patch: patch.clone(),
58                    });
59                }
60            } else {
61                for (purl, patch) in &manifest.patches {
62                    if patch.uuid == id {
63                        patches.push(PatchToRollback {
64                            purl: purl.clone(),
65                            patch: patch.clone(),
66                        });
67                    }
68                }
69            }
70            patches
71        }
72    }
73}
74
75fn get_before_hash_blobs(manifest: &PatchManifest) -> HashSet<String> {
76    let mut blobs = HashSet::new();
77    for patch in manifest.patches.values() {
78        for file_info in patch.files.values() {
79            blobs.insert(file_info.before_hash.clone());
80        }
81    }
82    blobs
83}
84
85async fn get_missing_before_blobs(
86    manifest: &PatchManifest,
87    blobs_path: &Path,
88) -> HashSet<String> {
89    let before_blobs = get_before_hash_blobs(manifest);
90    let mut missing = HashSet::new();
91    for hash in before_blobs {
92        let blob_path = blobs_path.join(&hash);
93        if tokio::fs::metadata(&blob_path).await.is_err() {
94            missing.insert(hash);
95        }
96    }
97    missing
98}
99
100fn verify_rollback_status_str(status: &VerifyRollbackStatus) -> &'static str {
101    match status {
102        VerifyRollbackStatus::Ready => "ready",
103        VerifyRollbackStatus::AlreadyOriginal => "already_original",
104        VerifyRollbackStatus::HashMismatch => "hash_mismatch",
105        VerifyRollbackStatus::NotFound => "not_found",
106        VerifyRollbackStatus::MissingBlob => "missing_blob",
107    }
108}
109
110fn result_to_json(result: &RollbackResult) -> serde_json::Value {
111    serde_json::json!({
112        "purl": result.package_key,
113        "path": result.package_path,
114        "success": result.success,
115        "error": result.error,
116        "filesRolledBack": result.files_rolled_back,
117        "filesVerified": result.files_verified.iter().map(|f| {
118            serde_json::json!({
119                "file": f.file,
120                "status": verify_rollback_status_str(&f.status),
121                "message": f.message,
122                "currentHash": f.current_hash,
123                "expectedHash": f.expected_hash,
124                "targetHash": f.target_hash,
125            })
126        }).collect::<Vec<_>>(),
127    })
128}
129
130pub async fn run(args: RollbackArgs) -> i32 {
131    apply_env_toggles(&args.common);
132
133    let (telemetry_client, _) =
134        get_api_client_with_overrides(args.common.api_client_overrides()).await;
135    let api_token = telemetry_client.api_token().cloned();
136    let org_slug = telemetry_client.org_slug().cloned();
137
138    // Validate one-off requires identifier
139    if args.one_off && args.identifier.is_none() {
140        if args.common.json {
141            println!("{}", serde_json::to_string_pretty(&serde_json::json!({
142                "status": "error",
143                "error": "--one-off requires an identifier (UUID or PURL)",
144            })).unwrap());
145        } else {
146            eprintln!("Error: --one-off requires an identifier (UUID or PURL)");
147        }
148        return 1;
149    }
150
151    // Handle one-off mode
152    if args.one_off {
153        if args.common.json {
154            println!("{}", serde_json::to_string_pretty(&serde_json::json!({
155                "status": "error",
156                "error": "One-off rollback mode is not yet implemented",
157            })).unwrap());
158        } else {
159            eprintln!("One-off rollback mode: fetching patch data...");
160        }
161        return 1;
162    }
163
164    let manifest_path = args.common.resolved_manifest_path();
165
166    if tokio::fs::metadata(&manifest_path).await.is_err() {
167        if args.common.json {
168            println!("{}", serde_json::to_string_pretty(&serde_json::json!({
169                "status": "error",
170                "error": "Manifest not found",
171                "path": manifest_path.display().to_string(),
172            })).unwrap());
173        } else if !args.common.silent {
174            eprintln!("Manifest not found at {}", manifest_path.display());
175        }
176        return 1;
177    }
178
179    // Serialize against concurrent socket-patch runs targeting the
180    // same `.socket/` directory. See
181    // `socket_patch_core::patch::apply_lock`.
182    let socket_dir = manifest_path.parent().unwrap_or(Path::new("."));
183    let acquired = match acquire_or_emit(
184        socket_dir,
185        EnvelopeCommand::Rollback,
186        args.common.json,
187        args.common.silent,
188        args.common.dry_run,
189        Duration::from_secs(args.common.lock_timeout.unwrap_or(0)),
190        args.common.break_lock,
191    ) {
192        Ok(acquired) => acquired,
193        Err(code) => return code,
194    };
195    let _lock = acquired.guard;
196    let lock_was_broken = acquired.broke_lock;
197
198    match rollback_patches_inner(&args, &manifest_path).await {
199        Ok((success, results)) => {
200            let rolled_back_count = results
201                .iter()
202                .filter(|r| r.success && !r.files_rolled_back.is_empty())
203                .count();
204            let already_original_count = results
205                .iter()
206                .filter(|r| {
207                    r.success
208                        && r.files_verified.iter().all(|f| {
209                            f.status == VerifyRollbackStatus::AlreadyOriginal
210                        })
211                })
212                .count();
213            let failed_count = results.iter().filter(|r| !r.success).count();
214
215            if args.common.json {
216                // `warnings` carries non-fatal audit info — currently
217                // just the `lock_broken` notice when --break-lock fired.
218                // Empty array stays present in the JSON shape so
219                // consumers can rely on `.warnings[]` without
220                // null-checking.
221                let mut warnings = Vec::new();
222                if lock_was_broken {
223                    warnings.push(serde_json::json!({
224                        "code": LOCK_BROKEN_CODE,
225                        "message": format!(
226                            "--break-lock removed {}/apply.lock before acquisition",
227                            socket_dir.display()
228                        ),
229                    }));
230                }
231                println!("{}", serde_json::to_string_pretty(&serde_json::json!({
232                    "status": if success { "success" } else { "partial_failure" },
233                    "rolledBack": rolled_back_count,
234                    "alreadyOriginal": already_original_count,
235                    "failed": failed_count,
236                    "dryRun": args.common.dry_run,
237                    "warnings": warnings,
238                    "results": results.iter().map(result_to_json).collect::<Vec<_>>(),
239                })).unwrap());
240            } else if !args.common.silent && !results.is_empty() {
241                let rolled_back: Vec<_> = results
242                    .iter()
243                    .filter(|r| r.success && !r.files_rolled_back.is_empty())
244                    .collect();
245                let already_original: Vec<_> = results
246                    .iter()
247                    .filter(|r| {
248                        r.success
249                            && r.files_verified.iter().all(|f| {
250                                f.status == VerifyRollbackStatus::AlreadyOriginal
251                            })
252                    })
253                    .collect();
254                let failed: Vec<_> = results.iter().filter(|r| !r.success).collect();
255
256                if args.common.dry_run {
257                    println!("\nRollback verification complete:");
258                    let can_rollback = results.iter().filter(|r| r.success).count();
259                    println!("  {can_rollback} package(s) can be rolled back");
260                    if !already_original.is_empty() {
261                        println!(
262                            "  {} package(s) already in original state",
263                            already_original.len()
264                        );
265                    }
266                    if !failed.is_empty() {
267                        println!("  {} package(s) cannot be rolled back", failed.len());
268                    }
269                } else {
270                    if !rolled_back.is_empty() || !already_original.is_empty() {
271                        println!("\nRolled back packages:");
272                        for result in &rolled_back {
273                            println!("  {}", result.package_key);
274                        }
275                        for result in &already_original {
276                            println!("  {} (already original)", result.package_key);
277                        }
278                    }
279                    if !failed.is_empty() {
280                        println!("\nFailed to rollback:");
281                        for result in &failed {
282                            println!(
283                                "  {}: {}",
284                                result.package_key,
285                                result.error.as_deref().unwrap_or("unknown error")
286                            );
287                        }
288                    }
289                }
290
291                if args.common.verbose {
292                    println!("\nDetailed verification:");
293                    for result in &results {
294                        println!("  {}:", result.package_key);
295                        for f in &result.files_verified {
296                            let status_str = match f.status {
297                                VerifyRollbackStatus::Ready => "ready",
298                                VerifyRollbackStatus::AlreadyOriginal => "already original",
299                                VerifyRollbackStatus::HashMismatch => "hash mismatch",
300                                VerifyRollbackStatus::NotFound => "not found",
301                                VerifyRollbackStatus::MissingBlob => "missing blob",
302                            };
303                            println!("    {} [{}]", f.file, status_str);
304                            if let Some(ref msg) = f.message {
305                                println!("      message: {msg}");
306                            }
307                            if let Some(ref h) = f.current_hash {
308                                println!("      current:  {h}");
309                            }
310                            if let Some(ref h) = f.expected_hash {
311                                println!("      expected: {h}");
312                            }
313                            if let Some(ref h) = f.target_hash {
314                                println!("      target:   {h}");
315                            }
316                        }
317                    }
318                }
319            }
320
321            if success {
322                track_patch_rolled_back(rolled_back_count, api_token.as_deref(), org_slug.as_deref()).await;
323            } else {
324                track_patch_rollback_failed("One or more rollbacks failed", api_token.as_deref(), org_slug.as_deref()).await;
325            }
326
327            if success { 0 } else { 1 }
328        }
329        Err(e) => {
330            track_patch_rollback_failed(&e, api_token.as_deref(), org_slug.as_deref()).await;
331            if args.common.json {
332                println!("{}", serde_json::to_string_pretty(&serde_json::json!({
333                    "status": "error",
334                    "error": e,
335                    "rolledBack": 0,
336                    "alreadyOriginal": 0,
337                    "failed": 0,
338                    "dryRun": args.common.dry_run,
339                    "results": [],
340                })).unwrap());
341            } else if !args.common.silent {
342                eprintln!("Error: {e}");
343            }
344            1
345        }
346    }
347}
348
349async fn rollback_patches_inner(
350    args: &RollbackArgs,
351    manifest_path: &Path,
352) -> Result<(bool, Vec<RollbackResult>), String> {
353    let manifest = read_manifest(manifest_path)
354        .await
355        .map_err(|e| e.to_string())?
356        .ok_or_else(|| "Invalid manifest".to_string())?;
357
358    let socket_dir = manifest_path.parent().unwrap();
359    let blobs_path = socket_dir.join("blobs");
360    tokio::fs::create_dir_all(&blobs_path)
361        .await
362        .map_err(|e| e.to_string())?;
363
364    let patches_to_rollback =
365        find_patches_to_rollback(&manifest, args.identifier.as_deref());
366
367    if patches_to_rollback.is_empty() {
368        if args.identifier.is_some() {
369            return Err(format!(
370                "No patch found matching identifier: {}",
371                args.identifier.as_deref().unwrap()
372            ));
373        }
374        if !args.common.silent && !args.common.json {
375            println!("No patches found in manifest");
376        }
377        return Ok((true, Vec::new()));
378    }
379
380    // Create filtered manifest
381    let filtered_manifest = PatchManifest {
382        patches: patches_to_rollback
383            .iter()
384            .map(|p| (p.purl.clone(), p.patch.clone()))
385            .collect(),
386    };
387
388    // Check for missing beforeHash blobs
389    let missing_blobs = get_missing_before_blobs(&filtered_manifest, &blobs_path).await;
390    if !missing_blobs.is_empty() {
391        if args.common.offline {
392            if !args.common.silent && !args.common.json {
393                eprintln!(
394                    "Error: {} blob(s) are missing and --offline mode is enabled.",
395                    missing_blobs.len()
396                );
397                eprintln!("Run \"socket-patch repair\" to download missing blobs.");
398            }
399            return Ok((false, Vec::new()));
400        }
401
402        if !args.common.silent && !args.common.json {
403            println!("Downloading {} missing blob(s)...", missing_blobs.len());
404        }
405
406        let (client, _) =
407            get_api_client_with_overrides(args.common.api_client_overrides()).await;
408        let fetch_result = fetch_blobs_by_hash(&missing_blobs, &blobs_path, &client, None).await;
409
410        if !args.common.silent && !args.common.json {
411            println!("{}", format_fetch_result(&fetch_result));
412        }
413
414        let still_missing = get_missing_before_blobs(&filtered_manifest, &blobs_path).await;
415        if !still_missing.is_empty() {
416            if !args.common.silent && !args.common.json {
417                eprintln!(
418                    "{} blob(s) could not be downloaded. Cannot rollback.",
419                    still_missing.len()
420                );
421            }
422            return Ok((false, Vec::new()));
423        }
424    }
425
426    // Partition PURLs by ecosystem
427    let rollback_purls: Vec<String> = patches_to_rollback.iter().map(|p| p.purl.clone()).collect();
428    let partitioned =
429        partition_purls(&rollback_purls, args.common.ecosystems.as_deref());
430
431    let crawler_options = CrawlerOptions {
432        cwd: args.common.cwd.clone(),
433        global: args.common.global,
434        global_prefix: args.common.global_prefix.clone(),
435        batch_size: 100,
436    };
437
438    let all_packages =
439        find_packages_for_rollback(&partitioned, &crawler_options, args.common.silent || args.common.json).await;
440
441    if all_packages.is_empty() {
442        if !args.common.silent && !args.common.json {
443            println!("No packages found that match patches to rollback");
444        }
445        return Ok((true, Vec::new()));
446    }
447
448    // Rollback patches
449    let mut results: Vec<RollbackResult> = Vec::new();
450    let mut has_errors = false;
451
452    for (purl, pkg_path) in &all_packages {
453        let patch = match filtered_manifest.patches.get(purl) {
454            Some(p) => p,
455            None => continue,
456        };
457
458        let result = rollback_package_patch(
459            purl,
460            pkg_path,
461            &patch.files,
462            &blobs_path,
463            args.common.dry_run,
464        )
465        .await;
466
467        if !result.success {
468            has_errors = true;
469            if !args.common.silent && !args.common.json {
470                eprintln!(
471                    "Failed to rollback {}: {}",
472                    purl,
473                    result.error.as_deref().unwrap_or("unknown error")
474                );
475            }
476        }
477        results.push(result);
478    }
479
480    Ok((!has_errors, results))
481}
482
483// Export for use by remove command
484#[allow(clippy::too_many_arguments)]
485pub async fn rollback_patches(
486    cwd: &Path,
487    manifest_path: &Path,
488    identifier: Option<&str>,
489    dry_run: bool,
490    silent: bool,
491    offline: bool,
492    global: bool,
493    global_prefix: Option<PathBuf>,
494    ecosystems: Option<Vec<String>>,
495) -> Result<(bool, Vec<RollbackResult>), String> {
496    let args = RollbackArgs {
497        identifier: identifier.map(String::from),
498        common: crate::args::GlobalArgs {
499            cwd: cwd.to_path_buf(),
500            manifest_path: manifest_path.display().to_string(),
501            offline,
502            global,
503            global_prefix,
504            ecosystems,
505            silent,
506            dry_run,
507            ..crate::args::GlobalArgs::default()
508        },
509        one_off: false,
510    };
511    rollback_patches_inner(&args, manifest_path).await
512}
513
#[cfg(test)]
mod tests {
    use super::*;
    use socket_patch_core::manifest::schema::{PatchManifest, PatchRecord};
    use std::collections::HashMap;

    /// Minimal record whose only distinguishing field is `uuid` (no files).
    fn make_record(uuid: &str) -> PatchRecord {
        PatchRecord {
            uuid: uuid.to_string(),
            exported_at: "2024-01-01T00:00:00Z".to_string(),
            files: HashMap::new(),
            vulnerabilities: HashMap::new(),
            description: "test patch".to_string(),
            license: "MIT".to_string(),
            tier: "free".to_string(),
        }
    }

    /// Three patches across two ecosystems, each with a distinct UUID.
    fn make_manifest() -> PatchManifest {
        let mut patches = HashMap::new();
        patches.insert("pkg:npm/foo@1.0".to_string(), make_record("uuid-foo"));
        patches.insert("pkg:npm/bar@2.0".to_string(), make_record("uuid-bar"));
        patches.insert("pkg:pypi/baz@3.0".to_string(), make_record("uuid-baz"));
        PatchManifest { patches }
    }

    #[test]
    fn test_find_patches_to_rollback_none_returns_all() {
        let manifest = make_manifest();
        let result = find_patches_to_rollback(&manifest, None);
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_find_patches_to_rollback_purl_match() {
        let manifest = make_manifest();
        let result =
            find_patches_to_rollback(&manifest, Some("pkg:npm/foo@1.0"));
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].purl, "pkg:npm/foo@1.0");
    }

    #[test]
    fn test_find_patches_to_rollback_purl_no_match() {
        let manifest = make_manifest();
        let result =
            find_patches_to_rollback(&manifest, Some("pkg:npm/nonexistent@1"));
        assert!(result.is_empty());
    }

    #[test]
    fn test_find_patches_to_rollback_uuid_match() {
        let manifest = make_manifest();
        let result = find_patches_to_rollback(&manifest, Some("uuid-bar"));
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].patch.uuid, "uuid-bar");
        assert_eq!(result[0].purl, "pkg:npm/bar@2.0");
    }

    #[test]
    fn test_find_patches_to_rollback_uuid_no_match() {
        let manifest = make_manifest();
        let result =
            find_patches_to_rollback(&manifest, Some("uuid-does-not-exist"));
        assert!(result.is_empty());
    }

    #[test]
    fn test_find_patches_to_rollback_uuid_shared_by_multiple_purls() {
        // UUID lookup must return every PURL carrying that UUID, not just one.
        let mut manifest = make_manifest();
        manifest
            .patches
            .insert("pkg:npm/qux@4.0".to_string(), make_record("uuid-foo"));
        let mut purls: Vec<String> = find_patches_to_rollback(&manifest, Some("uuid-foo"))
            .into_iter()
            .map(|p| p.purl)
            .collect();
        purls.sort();
        assert_eq!(purls, vec!["pkg:npm/foo@1.0", "pkg:npm/qux@4.0"]);
    }

    #[test]
    fn test_find_patches_to_rollback_pkg_prefix_never_matches_uuid() {
        // A `pkg:` identifier is only looked up as a manifest key.
        let mut manifest = make_manifest();
        manifest
            .patches
            .insert("pkg:npm/odd@1.0".to_string(), make_record("pkg:npm/ghost@1.0"));
        let result = find_patches_to_rollback(&manifest, Some("pkg:npm/ghost@1.0"));
        assert!(result.is_empty());
    }

    #[test]
    fn test_get_before_hash_blobs_without_files_is_empty() {
        assert!(get_before_hash_blobs(&make_manifest()).is_empty());
        let empty = PatchManifest { patches: HashMap::new() };
        assert!(get_before_hash_blobs(&empty).is_empty());
    }
}