1use clap::Args;
2use socket_patch_core::api::blob_fetcher::{
3 fetch_blobs_by_hash, format_fetch_result,
4};
5use socket_patch_core::api::client::get_api_client_with_overrides;
6use socket_patch_core::crawlers::CrawlerOptions;
7use socket_patch_core::manifest::operations::read_manifest;
8use socket_patch_core::manifest::schema::{PatchManifest, PatchRecord};
9use socket_patch_core::patch::rollback::{rollback_package_patch, RollbackResult, VerifyRollbackStatus};
10use socket_patch_core::utils::telemetry::{track_patch_rolled_back, track_patch_rollback_failed};
11use std::collections::HashSet;
12use std::path::{Path, PathBuf};
13use std::time::Duration;
14
15use crate::args::{apply_env_toggles, GlobalArgs};
16use crate::commands::lock_cli::{acquire_or_emit, LOCK_BROKEN_CODE};
17use crate::ecosystem_dispatch::{find_packages_for_rollback, partition_purls};
18use crate::json_envelope::Command as EnvelopeCommand;
19
// Command-line arguments of the `rollback` subcommand.
//
// NOTE: plain `//` comments are used on purpose — clap's derive turns `///`
// doc comments into `--help` text, so documenting with them would change the
// CLI output.
#[derive(Args)]
pub struct RollbackArgs {
    // What to roll back: an identifier starting with `pkg:` is matched as a
    // PURL, anything else as a patch UUID. `None` selects every patch in the
    // manifest.
    pub identifier: Option<String>,

    // Options shared with the other subcommands (manifest path, `--json`,
    // `--silent`, `--dry-run`, lock handling, ...).
    #[command(flatten)]
    pub common: GlobalArgs,

    // `--one-off` (or the SOCKET_ONE_OFF environment variable). Requires an
    // identifier; `run` currently rejects it as not yet implemented.
    #[arg(long = "one-off", env = "SOCKET_ONE_OFF", default_value_t = false)]
    pub one_off: bool,
}
32
/// A manifest entry selected for rollback.
struct PatchToRollback {
    /// Manifest key (package PURL) the patch is recorded under.
    purl: String,
    /// Copy of the patch record stored under `purl`.
    patch: PatchRecord,
}
37
38fn find_patches_to_rollback(
39 manifest: &PatchManifest,
40 identifier: Option<&str>,
41) -> Vec<PatchToRollback> {
42 match identifier {
43 None => manifest
44 .patches
45 .iter()
46 .map(|(purl, patch)| PatchToRollback {
47 purl: purl.clone(),
48 patch: patch.clone(),
49 })
50 .collect(),
51 Some(id) => {
52 let mut patches = Vec::new();
53 if id.starts_with("pkg:") {
54 if let Some(patch) = manifest.patches.get(id) {
55 patches.push(PatchToRollback {
56 purl: id.to_string(),
57 patch: patch.clone(),
58 });
59 }
60 } else {
61 for (purl, patch) in &manifest.patches {
62 if patch.uuid == id {
63 patches.push(PatchToRollback {
64 purl: purl.clone(),
65 patch: patch.clone(),
66 });
67 }
68 }
69 }
70 patches
71 }
72 }
73}
74
75fn get_before_hash_blobs(manifest: &PatchManifest) -> HashSet<String> {
76 let mut blobs = HashSet::new();
77 for patch in manifest.patches.values() {
78 for file_info in patch.files.values() {
79 blobs.insert(file_info.before_hash.clone());
80 }
81 }
82 blobs
83}
84
85async fn get_missing_before_blobs(
86 manifest: &PatchManifest,
87 blobs_path: &Path,
88) -> HashSet<String> {
89 let before_blobs = get_before_hash_blobs(manifest);
90 let mut missing = HashSet::new();
91 for hash in before_blobs {
92 let blob_path = blobs_path.join(&hash);
93 if tokio::fs::metadata(&blob_path).await.is_err() {
94 missing.insert(hash);
95 }
96 }
97 missing
98}
99
100fn verify_rollback_status_str(status: &VerifyRollbackStatus) -> &'static str {
101 match status {
102 VerifyRollbackStatus::Ready => "ready",
103 VerifyRollbackStatus::AlreadyOriginal => "already_original",
104 VerifyRollbackStatus::HashMismatch => "hash_mismatch",
105 VerifyRollbackStatus::NotFound => "not_found",
106 VerifyRollbackStatus::MissingBlob => "missing_blob",
107 }
108}
109
110fn result_to_json(result: &RollbackResult) -> serde_json::Value {
111 serde_json::json!({
112 "purl": result.package_key,
113 "path": result.package_path,
114 "success": result.success,
115 "error": result.error,
116 "filesRolledBack": result.files_rolled_back,
117 "filesVerified": result.files_verified.iter().map(|f| {
118 serde_json::json!({
119 "file": f.file,
120 "status": verify_rollback_status_str(&f.status),
121 "message": f.message,
122 "currentHash": f.current_hash,
123 "expectedHash": f.expected_hash,
124 "targetHash": f.target_hash,
125 })
126 }).collect::<Vec<_>>(),
127 })
128}
129
130pub async fn run(args: RollbackArgs) -> i32 {
131 apply_env_toggles(&args.common);
132
133 let (telemetry_client, _) =
134 get_api_client_with_overrides(args.common.api_client_overrides()).await;
135 let api_token = telemetry_client.api_token().cloned();
136 let org_slug = telemetry_client.org_slug().cloned();
137
138 if args.one_off && args.identifier.is_none() {
140 if args.common.json {
141 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
142 "status": "error",
143 "error": "--one-off requires an identifier (UUID or PURL)",
144 })).unwrap());
145 } else {
146 eprintln!("Error: --one-off requires an identifier (UUID or PURL)");
147 }
148 return 1;
149 }
150
151 if args.one_off {
153 if args.common.json {
154 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
155 "status": "error",
156 "error": "One-off rollback mode is not yet implemented",
157 })).unwrap());
158 } else {
159 eprintln!("One-off rollback mode: fetching patch data...");
160 }
161 return 1;
162 }
163
164 let manifest_path = args.common.resolved_manifest_path();
165
166 if tokio::fs::metadata(&manifest_path).await.is_err() {
167 if args.common.json {
168 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
169 "status": "error",
170 "error": "Manifest not found",
171 "path": manifest_path.display().to_string(),
172 })).unwrap());
173 } else if !args.common.silent {
174 eprintln!("Manifest not found at {}", manifest_path.display());
175 }
176 return 1;
177 }
178
179 let socket_dir = manifest_path.parent().unwrap_or(Path::new("."));
183 let acquired = match acquire_or_emit(
184 socket_dir,
185 EnvelopeCommand::Rollback,
186 args.common.json,
187 args.common.silent,
188 args.common.dry_run,
189 Duration::from_secs(args.common.lock_timeout.unwrap_or(0)),
190 args.common.break_lock,
191 ) {
192 Ok(acquired) => acquired,
193 Err(code) => return code,
194 };
195 let _lock = acquired.guard;
196 let lock_was_broken = acquired.broke_lock;
197
198 match rollback_patches_inner(&args, &manifest_path).await {
199 Ok((success, results)) => {
200 let rolled_back_count = results
201 .iter()
202 .filter(|r| r.success && !r.files_rolled_back.is_empty())
203 .count();
204 let already_original_count = results
205 .iter()
206 .filter(|r| {
207 r.success
208 && r.files_verified.iter().all(|f| {
209 f.status == VerifyRollbackStatus::AlreadyOriginal
210 })
211 })
212 .count();
213 let failed_count = results.iter().filter(|r| !r.success).count();
214
215 if args.common.json {
216 let mut warnings = Vec::new();
222 if lock_was_broken {
223 warnings.push(serde_json::json!({
224 "code": LOCK_BROKEN_CODE,
225 "message": format!(
226 "--break-lock removed {}/apply.lock before acquisition",
227 socket_dir.display()
228 ),
229 }));
230 }
231 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
232 "status": if success { "success" } else { "partial_failure" },
233 "rolledBack": rolled_back_count,
234 "alreadyOriginal": already_original_count,
235 "failed": failed_count,
236 "dryRun": args.common.dry_run,
237 "warnings": warnings,
238 "results": results.iter().map(result_to_json).collect::<Vec<_>>(),
239 })).unwrap());
240 } else if !args.common.silent && !results.is_empty() {
241 let rolled_back: Vec<_> = results
242 .iter()
243 .filter(|r| r.success && !r.files_rolled_back.is_empty())
244 .collect();
245 let already_original: Vec<_> = results
246 .iter()
247 .filter(|r| {
248 r.success
249 && r.files_verified.iter().all(|f| {
250 f.status == VerifyRollbackStatus::AlreadyOriginal
251 })
252 })
253 .collect();
254 let failed: Vec<_> = results.iter().filter(|r| !r.success).collect();
255
256 if args.common.dry_run {
257 println!("\nRollback verification complete:");
258 let can_rollback = results.iter().filter(|r| r.success).count();
259 println!(" {can_rollback} package(s) can be rolled back");
260 if !already_original.is_empty() {
261 println!(
262 " {} package(s) already in original state",
263 already_original.len()
264 );
265 }
266 if !failed.is_empty() {
267 println!(" {} package(s) cannot be rolled back", failed.len());
268 }
269 } else {
270 if !rolled_back.is_empty() || !already_original.is_empty() {
271 println!("\nRolled back packages:");
272 for result in &rolled_back {
273 println!(" {}", result.package_key);
274 }
275 for result in &already_original {
276 println!(" {} (already original)", result.package_key);
277 }
278 }
279 if !failed.is_empty() {
280 println!("\nFailed to rollback:");
281 for result in &failed {
282 println!(
283 " {}: {}",
284 result.package_key,
285 result.error.as_deref().unwrap_or("unknown error")
286 );
287 }
288 }
289 }
290
291 if args.common.verbose {
292 println!("\nDetailed verification:");
293 for result in &results {
294 println!(" {}:", result.package_key);
295 for f in &result.files_verified {
296 let status_str = match f.status {
297 VerifyRollbackStatus::Ready => "ready",
298 VerifyRollbackStatus::AlreadyOriginal => "already original",
299 VerifyRollbackStatus::HashMismatch => "hash mismatch",
300 VerifyRollbackStatus::NotFound => "not found",
301 VerifyRollbackStatus::MissingBlob => "missing blob",
302 };
303 println!(" {} [{}]", f.file, status_str);
304 if let Some(ref msg) = f.message {
305 println!(" message: {msg}");
306 }
307 if let Some(ref h) = f.current_hash {
308 println!(" current: {h}");
309 }
310 if let Some(ref h) = f.expected_hash {
311 println!(" expected: {h}");
312 }
313 if let Some(ref h) = f.target_hash {
314 println!(" target: {h}");
315 }
316 }
317 }
318 }
319 }
320
321 if success {
322 track_patch_rolled_back(rolled_back_count, api_token.as_deref(), org_slug.as_deref()).await;
323 } else {
324 track_patch_rollback_failed("One or more rollbacks failed", api_token.as_deref(), org_slug.as_deref()).await;
325 }
326
327 if success { 0 } else { 1 }
328 }
329 Err(e) => {
330 track_patch_rollback_failed(&e, api_token.as_deref(), org_slug.as_deref()).await;
331 if args.common.json {
332 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
333 "status": "error",
334 "error": e,
335 "rolledBack": 0,
336 "alreadyOriginal": 0,
337 "failed": 0,
338 "dryRun": args.common.dry_run,
339 "results": [],
340 })).unwrap());
341 } else if !args.common.silent {
342 eprintln!("Error: {e}");
343 }
344 1
345 }
346 }
347}
348
349async fn rollback_patches_inner(
350 args: &RollbackArgs,
351 manifest_path: &Path,
352) -> Result<(bool, Vec<RollbackResult>), String> {
353 let manifest = read_manifest(manifest_path)
354 .await
355 .map_err(|e| e.to_string())?
356 .ok_or_else(|| "Invalid manifest".to_string())?;
357
358 let socket_dir = manifest_path.parent().unwrap();
359 let blobs_path = socket_dir.join("blobs");
360 tokio::fs::create_dir_all(&blobs_path)
361 .await
362 .map_err(|e| e.to_string())?;
363
364 let patches_to_rollback =
365 find_patches_to_rollback(&manifest, args.identifier.as_deref());
366
367 if patches_to_rollback.is_empty() {
368 if args.identifier.is_some() {
369 return Err(format!(
370 "No patch found matching identifier: {}",
371 args.identifier.as_deref().unwrap()
372 ));
373 }
374 if !args.common.silent && !args.common.json {
375 println!("No patches found in manifest");
376 }
377 return Ok((true, Vec::new()));
378 }
379
380 let filtered_manifest = PatchManifest {
382 patches: patches_to_rollback
383 .iter()
384 .map(|p| (p.purl.clone(), p.patch.clone()))
385 .collect(),
386 };
387
388 let missing_blobs = get_missing_before_blobs(&filtered_manifest, &blobs_path).await;
390 if !missing_blobs.is_empty() {
391 if args.common.offline {
392 if !args.common.silent && !args.common.json {
393 eprintln!(
394 "Error: {} blob(s) are missing and --offline mode is enabled.",
395 missing_blobs.len()
396 );
397 eprintln!("Run \"socket-patch repair\" to download missing blobs.");
398 }
399 return Ok((false, Vec::new()));
400 }
401
402 if !args.common.silent && !args.common.json {
403 println!("Downloading {} missing blob(s)...", missing_blobs.len());
404 }
405
406 let (client, _) =
407 get_api_client_with_overrides(args.common.api_client_overrides()).await;
408 let fetch_result = fetch_blobs_by_hash(&missing_blobs, &blobs_path, &client, None).await;
409
410 if !args.common.silent && !args.common.json {
411 println!("{}", format_fetch_result(&fetch_result));
412 }
413
414 let still_missing = get_missing_before_blobs(&filtered_manifest, &blobs_path).await;
415 if !still_missing.is_empty() {
416 if !args.common.silent && !args.common.json {
417 eprintln!(
418 "{} blob(s) could not be downloaded. Cannot rollback.",
419 still_missing.len()
420 );
421 }
422 return Ok((false, Vec::new()));
423 }
424 }
425
426 let rollback_purls: Vec<String> = patches_to_rollback.iter().map(|p| p.purl.clone()).collect();
428 let partitioned =
429 partition_purls(&rollback_purls, args.common.ecosystems.as_deref());
430
431 let crawler_options = CrawlerOptions {
432 cwd: args.common.cwd.clone(),
433 global: args.common.global,
434 global_prefix: args.common.global_prefix.clone(),
435 batch_size: 100,
436 };
437
438 let all_packages =
439 find_packages_for_rollback(&partitioned, &crawler_options, args.common.silent || args.common.json).await;
440
441 if all_packages.is_empty() {
442 if !args.common.silent && !args.common.json {
443 println!("No packages found that match patches to rollback");
444 }
445 return Ok((true, Vec::new()));
446 }
447
448 let mut results: Vec<RollbackResult> = Vec::new();
450 let mut has_errors = false;
451
452 for (purl, pkg_path) in &all_packages {
453 let patch = match filtered_manifest.patches.get(purl) {
454 Some(p) => p,
455 None => continue,
456 };
457
458 let result = rollback_package_patch(
459 purl,
460 pkg_path,
461 &patch.files,
462 &blobs_path,
463 args.common.dry_run,
464 )
465 .await;
466
467 if !result.success {
468 has_errors = true;
469 if !args.common.silent && !args.common.json {
470 eprintln!(
471 "Failed to rollback {}: {}",
472 purl,
473 result.error.as_deref().unwrap_or("unknown error")
474 );
475 }
476 }
477 results.push(result);
478 }
479
480 Ok((!has_errors, results))
481}
482
483#[allow(clippy::too_many_arguments)]
485pub async fn rollback_patches(
486 cwd: &Path,
487 manifest_path: &Path,
488 identifier: Option<&str>,
489 dry_run: bool,
490 silent: bool,
491 offline: bool,
492 global: bool,
493 global_prefix: Option<PathBuf>,
494 ecosystems: Option<Vec<String>>,
495) -> Result<(bool, Vec<RollbackResult>), String> {
496 let args = RollbackArgs {
497 identifier: identifier.map(String::from),
498 common: crate::args::GlobalArgs {
499 cwd: cwd.to_path_buf(),
500 manifest_path: manifest_path.display().to_string(),
501 offline,
502 global,
503 global_prefix,
504 ecosystems,
505 silent,
506 dry_run,
507 ..crate::args::GlobalArgs::default()
508 },
509 one_off: false,
510 };
511 rollback_patches_inner(&args, manifest_path).await
512}
513
#[cfg(test)]
mod tests {
    // `super::*` already brings `PatchManifest`, `PatchRecord` and
    // `VerifyRollbackStatus` into scope.
    use super::*;
    use std::collections::HashMap;

    /// Minimal patch record with the given UUID and no files.
    fn make_record(uuid: &str) -> PatchRecord {
        PatchRecord {
            uuid: uuid.to_string(),
            exported_at: "2024-01-01T00:00:00Z".to_string(),
            files: HashMap::new(),
            vulnerabilities: HashMap::new(),
            description: "test patch".to_string(),
            license: "MIT".to_string(),
            tier: "free".to_string(),
        }
    }

    /// Manifest with three patches across two ecosystems.
    fn make_manifest() -> PatchManifest {
        let mut patches = HashMap::new();
        patches.insert("pkg:npm/foo@1.0".to_string(), make_record("uuid-foo"));
        patches.insert("pkg:npm/bar@2.0".to_string(), make_record("uuid-bar"));
        patches.insert("pkg:pypi/baz@3.0".to_string(), make_record("uuid-baz"));
        PatchManifest { patches }
    }

    #[test]
    fn test_find_patches_to_rollback_none_returns_all() {
        let manifest = make_manifest();
        let result = find_patches_to_rollback(&manifest, None);
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_find_patches_to_rollback_empty_manifest() {
        let manifest = PatchManifest {
            patches: HashMap::new(),
        };
        assert!(find_patches_to_rollback(&manifest, None).is_empty());
        assert!(find_patches_to_rollback(&manifest, Some("uuid-foo")).is_empty());
        assert!(find_patches_to_rollback(&manifest, Some("pkg:npm/foo@1.0")).is_empty());
    }

    #[test]
    fn test_find_patches_to_rollback_purl_match() {
        let manifest = make_manifest();
        let result = find_patches_to_rollback(&manifest, Some("pkg:npm/foo@1.0"));
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].purl, "pkg:npm/foo@1.0");
    }

    #[test]
    fn test_find_patches_to_rollback_purl_no_match() {
        let manifest = make_manifest();
        let result = find_patches_to_rollback(&manifest, Some("pkg:npm/nonexistent@1"));
        assert!(result.is_empty());
    }

    #[test]
    fn test_find_patches_to_rollback_purl_without_prefix_is_treated_as_uuid() {
        // Without the `pkg:` prefix the identifier is only compared to UUIDs.
        let manifest = make_manifest();
        assert!(find_patches_to_rollback(&manifest, Some("npm/foo@1.0")).is_empty());
    }

    #[test]
    fn test_find_patches_to_rollback_uuid_match() {
        let manifest = make_manifest();
        let result = find_patches_to_rollback(&manifest, Some("uuid-bar"));
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].patch.uuid, "uuid-bar");
        assert_eq!(result[0].purl, "pkg:npm/bar@2.0");
    }

    #[test]
    fn test_find_patches_to_rollback_uuid_matches_every_purl() {
        let mut manifest = make_manifest();
        manifest
            .patches
            .insert("pkg:npm/foo@1.1".to_string(), make_record("uuid-foo"));
        // HashMap iteration order is unspecified, so compare sorted PURLs.
        let mut purls: Vec<String> = find_patches_to_rollback(&manifest, Some("uuid-foo"))
            .into_iter()
            .map(|p| p.purl)
            .collect();
        purls.sort();
        assert_eq!(purls, vec!["pkg:npm/foo@1.0", "pkg:npm/foo@1.1"]);
    }

    #[test]
    fn test_find_patches_to_rollback_uuid_no_match() {
        let manifest = make_manifest();
        let result = find_patches_to_rollback(&manifest, Some("uuid-does-not-exist"));
        assert!(result.is_empty());
    }

    #[test]
    fn test_verify_rollback_status_str() {
        assert_eq!(verify_rollback_status_str(&VerifyRollbackStatus::Ready), "ready");
        assert_eq!(
            verify_rollback_status_str(&VerifyRollbackStatus::AlreadyOriginal),
            "already_original"
        );
        assert_eq!(
            verify_rollback_status_str(&VerifyRollbackStatus::HashMismatch),
            "hash_mismatch"
        );
        assert_eq!(
            verify_rollback_status_str(&VerifyRollbackStatus::NotFound),
            "not_found"
        );
        assert_eq!(
            verify_rollback_status_str(&VerifyRollbackStatus::MissingBlob),
            "missing_blob"
        );
    }
}