1use clap::Args;
2use regex::Regex;
3use socket_patch_core::api::client::{
4 build_proxy_fallback_client, get_api_client_with_overrides, is_fallback_candidate,
5};
6use socket_patch_core::api::types::{
7 PatchResponse, PatchSearchResult, SearchResponse, VulnerabilityResponse,
8};
9use socket_patch_core::crawlers::CrawlerOptions;
10use socket_patch_core::manifest::operations::{read_manifest, write_manifest};
11use socket_patch_core::manifest::schema::{
12 PatchFileInfo, PatchManifest, PatchRecord, VulnerabilityInfo,
13};
14use socket_patch_core::utils::fuzzy_match::fuzzy_match_packages;
15use socket_patch_core::utils::purl::is_purl;
16use socket_patch_core::utils::telemetry::{track_patch_fetch_failed, track_patch_fetched};
17use std::collections::HashMap;
18use std::fmt;
19use std::path::PathBuf;
20
21use crate::args::{apply_env_toggles, GlobalArgs};
22use crate::ecosystem_dispatch::crawl_all_ecosystems;
23use crate::output::{confirm, select_one, SelectError};
24
/// Returns the ecosystem segment of a PURL, e.g. `"npm"` for
/// `pkg:npm/lodash@4.17.21`.
///
/// Everything after `pkg:` up to the first `/` is returned; when there is no
/// `/` the whole remainder is used. Inputs without the `pkg:` prefix yield an
/// empty string.
fn ecosystem_from_purl(purl: &str) -> String {
    match purl.strip_prefix("pkg:") {
        Some(rest) => rest
            .split_once('/')
            .map_or(rest, |(ecosystem, _)| ecosystem)
            .to_string(),
        None => String::new(),
    }
}
35
/// Outcome of comparing a fetched patch against the manifest entry for the
/// same PURL (produced by `decide_patch_action`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum PatchAction {
    /// The manifest has no entry for the PURL.
    Added,
    /// The manifest holds a different patch for the PURL; `old_uuid` is the
    /// UUID that gets replaced.
    Updated { old_uuid: String },
    /// The manifest already holds exactly this patch UUID.
    Skipped,
}
49
50pub(crate) fn decide_patch_action(
53 manifest: &PatchManifest,
54 purl: &str,
55 new_uuid: &str,
56) -> PatchAction {
57 match manifest.patches.get(purl) {
58 Some(existing) if existing.uuid == new_uuid => PatchAction::Skipped,
59 Some(existing) => PatchAction::Updated {
60 old_uuid: existing.uuid.clone(),
61 },
62 None => PatchAction::Added,
63 }
64}
65
/// Maps a severity label to a sortable rank, ignoring ASCII case.
///
/// `critical` = 4, `high` = 3, `moderate`/`medium` = 2, `low` = 1; any other
/// label (including empty or padded strings) ranks 0.
pub(crate) fn severity_rank(severity: &str) -> u8 {
    const RANKS: [(&str, u8); 5] = [
        ("critical", 4),
        ("high", 3),
        ("moderate", 2),
        ("medium", 2),
        ("low", 1),
    ];
    RANKS
        .iter()
        .find(|(label, _)| severity.eq_ignore_ascii_case(label))
        .map_or(0, |&(_, rank)| rank)
}
79
80pub(crate) fn max_vuln_severity(
84 vulns: &HashMap<String, VulnerabilityResponse>,
85) -> Option<String> {
86 vulns
87 .values()
88 .max_by_key(|v| severity_rank(&v.severity))
89 .map(|v| v.severity.clone())
90}
91
92pub(crate) fn patch_event_metadata(patch: &PatchResponse) -> serde_json::Value {
103 let mut vulns: Vec<serde_json::Value> = patch
104 .vulnerabilities
105 .iter()
106 .map(|(id, v)| {
107 serde_json::json!({
108 "id": id,
109 "cves": v.cves,
110 "severity": v.severity,
111 "summary": v.summary,
112 "description": v.description,
113 })
114 })
115 .collect();
116 vulns.sort_by(|a, b| {
119 a["id"]
120 .as_str()
121 .unwrap_or("")
122 .cmp(b["id"].as_str().unwrap_or(""))
123 });
124
125 let mut meta = serde_json::Map::new();
126 meta.insert(
127 "description".into(),
128 serde_json::Value::String(patch.description.clone()),
129 );
130 meta.insert(
131 "license".into(),
132 serde_json::Value::String(patch.license.clone()),
133 );
134 meta.insert(
135 "tier".into(),
136 serde_json::Value::String(patch.tier.clone()),
137 );
138 meta.insert(
139 "exportedAt".into(),
140 serde_json::Value::String(patch.published_at.clone()),
141 );
142 if let Some(sev) = max_vuln_severity(&patch.vulnerabilities) {
143 meta.insert("severity".into(), serde_json::Value::String(sev));
144 }
145 meta.insert("vulnerabilities".into(), serde_json::Value::Array(vulns));
146 serde_json::Value::Object(meta)
147}
148
149fn merge_metadata(record: &mut serde_json::Value, meta: serde_json::Value) {
153 if let (Some(record_obj), serde_json::Value::Object(meta_obj)) =
154 (record.as_object_mut(), meta)
155 {
156 for (k, v) in meta_obj {
157 record_obj.insert(k, v);
158 }
159 }
160}
161
// Command-line arguments for the `get` command.
//
// NOTE: plain `//` comments are used on purpose. clap's derive turns `///` doc
// comments on this struct and its fields into `--help` text, which would
// change the CLI output.
#[derive(Args)]
pub struct GetArgs {
    // What to look up: a patch UUID, CVE ID, GHSA ID, PURL, or package name.
    // Without a type flag the kind is detected from its syntax (see `run`).
    pub identifier: String,

    // Shared options (cwd, JSON output, API overrides, ...).
    #[command(flatten)]
    pub common: GlobalArgs,

    // `--id`: treat `identifier` as a patch UUID. At most one of
    // --id/--cve/--ghsa/--package may be given (checked in `run`).
    #[arg(long, default_value_t = false)]
    pub id: bool,

    // `--cve`: treat `identifier` as a CVE ID.
    #[arg(long, default_value_t = false)]
    pub cve: bool,

    // `--ghsa`: treat `identifier` as a GHSA ID.
    #[arg(long, default_value_t = false)]
    pub ghsa: bool,

    // `-p` / `--package`: treat `identifier` as a package name that is
    // fuzzy-matched against the packages found by the crawlers.
    #[arg(short = 'p', long = "package", default_value_t = false)]
    pub package: bool,

    // `--save-only` (alias `--no-apply`, env `SOCKET_SAVE_ONLY`): store patches
    // in the manifest without running apply.
    #[arg(long = "save-only", alias = "no-apply", env = "SOCKET_SAVE_ONLY", default_value_t = false)]
    pub save_only: bool,

    // `--one-off` (env `SOCKET_ONE_OFF`): rejected together with `--save-only`.
    // NOTE(review): its effect is not visible in this section — confirm.
    #[arg(long = "one-off", env = "SOCKET_ONE_OFF", default_value_t = false)]
    pub one_off: bool,
}
194
/// Kind of identifier accepted by the `get` command.
#[derive(PartialEq, Debug)]
enum IdentifierType {
    /// A patch UUID.
    Uuid,
    /// A CVE ID such as `CVE-2021-44228`.
    Cve,
    /// A GitHub Security Advisory ID (`GHSA-xxxx-xxxx-xxxx`).
    Ghsa,
    /// A package URL (`pkg:...`).
    Purl,
    /// A free-form package name to fuzzy-match against installed packages.
    Package,
}

impl fmt::Display for IdentifierType {
    /// Writes the human-readable label used in "No patches found for ..." messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Uuid => "UUID",
            Self::Cve => "CVE",
            Self::Ghsa => "GHSA",
            Self::Purl => "PURL",
            Self::Package => "package name",
        };
        f.write_str(label)
    }
}
215
216fn detect_identifier_type(identifier: &str) -> Option<IdentifierType> {
217 let uuid_re = Regex::new(r"(?i)^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$").unwrap();
218 let cve_re = Regex::new(r"(?i)^CVE-\d{4}-\d+$").unwrap();
219 let ghsa_re = Regex::new(r"(?i)^GHSA-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}$").unwrap();
220
221 if uuid_re.is_match(identifier) {
222 Some(IdentifierType::Uuid)
223 } else if cve_re.is_match(identifier) {
224 Some(IdentifierType::Cve)
225 } else if ghsa_re.is_match(identifier) {
226 Some(IdentifierType::Ghsa)
227 } else if is_purl(identifier) {
228 Some(IdentifierType::Purl)
229 } else {
230 None
231 }
232}
233
234pub fn select_patches(
243 patches: &[PatchSearchResult],
244 can_access_paid: bool,
245 is_json: bool,
246) -> Result<Vec<PatchSearchResult>, i32> {
247 let mut by_purl: HashMap<String, Vec<&PatchSearchResult>> = HashMap::new();
249 for p in patches {
250 if p.tier == "free" || can_access_paid {
251 by_purl.entry(p.purl.clone()).or_default().push(p);
252 }
253 }
254
255 let mut selected = Vec::new();
256
257 for (purl, mut group) in by_purl {
258 group.sort_by(|a, b| b.published_at.cmp(&a.published_at));
260
261 if can_access_paid {
262 let choice = group
264 .iter()
265 .find(|p| p.tier == "paid")
266 .or_else(|| group.first())
267 .unwrap();
268 selected.push((*choice).clone());
269 } else if group.len() == 1 {
270 selected.push(group[0].clone());
271 } else {
272 let options: Vec<String> = group
274 .iter()
275 .map(|p| {
276 let vuln_summary: Vec<String> = p
277 .vulnerabilities
278 .iter()
279 .map(|(id, v)| {
280 if v.cves.is_empty() {
281 id.clone()
282 } else {
283 v.cves.join(", ")
284 }
285 })
286 .collect();
287 let vulns = if vuln_summary.is_empty() {
288 String::new()
289 } else {
290 format!(" (fixes: {})", vuln_summary.join(", "))
291 };
292 let desc = if p.description.len() > 60 {
293 format!("{}...", &p.description[..57])
294 } else {
295 p.description.clone()
296 };
297 format!("{} [{}]{} - {}", p.uuid, p.tier, vulns, desc)
298 })
299 .collect();
300
301 match select_one(
302 &format!("Multiple patches available for {purl}. Select one:"),
303 &options,
304 is_json,
305 ) {
306 Ok(idx) => {
307 selected.push(group[idx].clone());
308 }
309 Err(SelectError::JsonModeNeedsExplicit) => {
310 let options_json: Vec<serde_json::Value> = group
311 .iter()
312 .map(|p| {
313 let vulns: Vec<serde_json::Value> = p
314 .vulnerabilities
315 .iter()
316 .map(|(id, v)| {
317 serde_json::json!({
318 "id": id,
319 "cves": v.cves,
320 "severity": v.severity,
321 "summary": v.summary,
322 })
323 })
324 .collect();
325 serde_json::json!({
326 "uuid": p.uuid,
327 "tier": p.tier,
328 "published_at": p.published_at,
329 "description": p.description,
330 "vulnerabilities": vulns,
331 })
332 })
333 .collect();
334 println!(
335 "{}",
336 serde_json::to_string_pretty(&serde_json::json!({
337 "status": "selection_required",
338 "error": format!("Multiple patches available for {purl}. Specify --id <UUID> to select one."),
339 "purl": purl,
340 "options": options_json,
341 }))
342 .unwrap()
343 );
344 return Err(1);
345 }
346 Err(SelectError::Cancelled) => {
347 eprintln!("Selection cancelled.");
348 return Err(0);
349 }
350 }
351 }
352 }
353
354 Ok(selected)
355}
356
/// Inputs for [`download_and_apply_patches`], gathered by the calling command.
pub struct DownloadParams {
    /// Project directory; patches are stored under `<cwd>/.socket/`.
    pub cwd: PathBuf,
    /// Organization slug, used only when `api_overrides` does not already set one.
    pub org: Option<String>,
    /// Record patches in the manifest without running apply afterwards.
    pub save_only: bool,
    /// NOTE(review): not read by `download_and_apply_patches` — confirm where it is consumed.
    pub one_off: bool,
    /// Forwarded to apply as `GlobalArgs::global`.
    pub global: bool,
    /// Forwarded to apply as `GlobalArgs::global_prefix`.
    pub global_prefix: Option<PathBuf>,
    /// JSON mode: suppresses progress output and prints fatal errors as JSON.
    pub json: bool,
    /// Suppresses human-readable progress output.
    pub silent: bool,
    /// Forwarded to apply as `GlobalArgs::download_mode`.
    pub download_mode: String,
    /// API client configuration used to fetch patch details.
    pub api_overrides: socket_patch_core::api::client::ApiClientEnvOverrides,
}
375
376pub async fn download_and_apply_patches(
380 selected: &[PatchSearchResult],
381 params: &DownloadParams,
382) -> (i32, serde_json::Value) {
383 let mut overrides = params.api_overrides.clone();
384 if overrides.org_slug.is_none() {
385 overrides.org_slug = params.org.clone();
386 }
387 let (api_client, _) =
388 socket_patch_core::api::client::get_api_client_with_overrides(overrides).await;
389 let effective_org: Option<&str> = None;
390
391 let socket_dir = params.cwd.join(".socket");
392 let blobs_dir = socket_dir.join("blobs");
393 let manifest_path = socket_dir.join("manifest.json");
394
395 if let Err(e) = tokio::fs::create_dir_all(&socket_dir).await {
396 let err = format!("Failed to create .socket directory: {}", e);
397 if params.json {
398 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
399 "status": "error",
400 "error": &err,
401 })).unwrap());
402 } else {
403 eprintln!("Error: {}", &err);
404 }
405 return (1, serde_json::json!({"status": "error", "error": err}));
406 }
407 if let Err(e) = tokio::fs::create_dir_all(&blobs_dir).await {
408 let err = format!("Failed to create blobs directory: {}", e);
409 if params.json {
410 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
411 "status": "error",
412 "error": &err,
413 })).unwrap());
414 } else {
415 eprintln!("Error: {}", &err);
416 }
417 return (1, serde_json::json!({"status": "error", "error": err}));
418 }
419
420 let mut manifest = match read_manifest(&manifest_path).await {
421 Ok(Some(m)) => m,
422 _ => PatchManifest::new(),
423 };
424
425 if !params.json && !params.silent {
426 eprintln!("\nDownloading {} patch(es)...", selected.len());
427 }
428
429 let mut patches_added = 0;
430 let mut patches_skipped = 0;
431 let mut patches_failed = 0;
432 let mut downloaded_patches: Vec<serde_json::Value> = Vec::new();
433 let mut updates: Vec<String> = Vec::new();
434
435 for search_result in selected {
436 if let Some(existing) = manifest.patches.get(&search_result.purl) {
438 if existing.uuid != search_result.uuid {
439 updates.push(search_result.purl.clone());
440 if !params.json && !params.silent {
441 eprintln!(
442 " [update] {} (replacing {})",
443 search_result.purl,
444 &existing.uuid[..8]
445 );
446 }
447 }
448 }
449
450 match api_client
451 .fetch_patch(effective_org, &search_result.uuid)
452 .await
453 {
454 Ok(Some(patch)) => {
455 let action = decide_patch_action(&manifest, &patch.purl, &patch.uuid);
459 if let PatchAction::Skipped = action {
460 if !params.json && !params.silent {
461 eprintln!(" [skip] {} (already in manifest)", patch.purl);
462 }
463 downloaded_patches.push(serde_json::json!({
464 "purl": patch.purl,
465 "uuid": patch.uuid,
466 "action": "skipped",
467 }));
468 patches_skipped += 1;
469 continue;
470 }
471
472 let mut patch_failed = false;
474 let mut files = HashMap::new();
475 for (file_path, file_info) in &patch.files {
476 if let (Some(ref before), Some(ref after)) =
477 (&file_info.before_hash, &file_info.after_hash)
478 {
479 files.insert(
480 file_path.clone(),
481 PatchFileInfo {
482 before_hash: before.clone(),
483 after_hash: after.clone(),
484 },
485 );
486 }
487
488 if let (Some(ref blob_content), Some(ref after_hash)) =
489 (&file_info.blob_content, &file_info.after_hash)
490 {
491 match base64_decode(blob_content) {
492 Ok(decoded) => {
493 let blob_path = blobs_dir.join(after_hash);
494 if let Err(e) = tokio::fs::write(&blob_path, &decoded).await {
495 if !params.json && !params.silent {
496 eprintln!(" [error] Failed to write blob for {}: {}", file_path, e);
497 }
498 patch_failed = true;
499 break;
500 }
501 }
502 Err(e) => {
503 if !params.json && !params.silent {
504 eprintln!(" [error] Failed to decode blob for {}: {}", file_path, e);
505 }
506 patch_failed = true;
507 break;
508 }
509 }
510 }
511
512 if let (Some(ref before_blob), Some(ref before_hash)) =
514 (&file_info.before_blob_content, &file_info.before_hash)
515 {
516 match base64_decode(before_blob) {
517 Ok(decoded) => {
518 if let Err(e) = tokio::fs::write(blobs_dir.join(before_hash), &decoded).await {
519 if !params.json && !params.silent {
520 eprintln!(" [error] Failed to write before-blob for {}: {}", file_path, e);
521 }
522 patch_failed = true;
523 break;
524 }
525 }
526 Err(e) => {
527 if !params.json && !params.silent {
528 eprintln!(" [error] Failed to decode before-blob for {}: {}", file_path, e);
529 }
530 patch_failed = true;
531 break;
532 }
533 }
534 }
535 }
536
537 if patch_failed {
538 patches_failed += 1;
539 downloaded_patches.push(serde_json::json!({
540 "purl": patch.purl,
541 "uuid": patch.uuid,
542 "action": "failed",
543 "error": "Blob decode or write failed",
544 }));
545 continue;
546 }
547
548 let vulnerabilities: HashMap<String, VulnerabilityInfo> = patch
549 .vulnerabilities
550 .iter()
551 .map(|(id, v)| {
552 (
553 id.clone(),
554 VulnerabilityInfo {
555 cves: v.cves.clone(),
556 summary: v.summary.clone(),
557 severity: v.severity.clone(),
558 description: v.description.clone(),
559 },
560 )
561 })
562 .collect();
563
564 manifest.patches.insert(
565 patch.purl.clone(),
566 PatchRecord {
567 uuid: patch.uuid.clone(),
568 exported_at: patch.published_at.clone(),
569 files,
570 vulnerabilities,
571 description: patch.description.clone(),
572 license: patch.license.clone(),
573 tier: patch.tier.clone(),
574 },
575 );
576
577 let mut action_record = match &action {
578 PatchAction::Updated { old_uuid } => {
579 if !params.json && !params.silent {
580 eprintln!(" [update] {}", patch.purl);
581 }
582 serde_json::json!({
583 "purl": patch.purl,
584 "uuid": patch.uuid,
585 "action": "updated",
586 "oldUuid": old_uuid,
587 })
588 }
589 _ => {
590 if !params.json && !params.silent {
591 eprintln!(" [add] {}", patch.purl);
592 }
593 serde_json::json!({
594 "purl": patch.purl,
595 "uuid": patch.uuid,
596 "action": "added",
597 })
598 }
599 };
600 merge_metadata(&mut action_record, patch_event_metadata(&patch));
605 downloaded_patches.push(action_record);
606 patches_added += 1;
607 }
608 Ok(None) => {
609 if !params.json && !params.silent {
610 eprintln!(" [fail] {} (could not fetch details)", search_result.purl);
611 }
612 downloaded_patches.push(serde_json::json!({
613 "purl": search_result.purl,
614 "uuid": search_result.uuid,
615 "action": "failed",
616 "error": "could not fetch details",
617 }));
618 patches_failed += 1;
619 }
620 Err(e) => {
621 if !params.json && !params.silent {
622 eprintln!(" [fail] {} ({e})", search_result.purl);
623 }
624 downloaded_patches.push(serde_json::json!({
625 "purl": search_result.purl,
626 "uuid": search_result.uuid,
627 "action": "failed",
628 "error": e.to_string(),
629 }));
630 patches_failed += 1;
631 }
632 }
633 }
634
635 if let Err(e) = write_manifest(&manifest_path, &manifest).await {
637 let err_json = serde_json::json!({
638 "status": "error",
639 "error": format!("Error writing manifest: {e}"),
640 });
641 if params.json {
642 println!("{}", serde_json::to_string_pretty(&err_json).unwrap());
643 } else {
644 eprintln!("Error writing manifest: {e}");
645 }
646 return (1, err_json);
647 }
648
649 if !params.json && !params.silent {
650 eprintln!("\nPatches saved to {}", manifest_path.display());
651 eprintln!(" Added: {patches_added}");
652 if patches_skipped > 0 {
653 eprintln!(" Skipped: {patches_skipped}");
654 }
655 if patches_failed > 0 {
656 eprintln!(" Failed: {patches_failed}");
657 }
658 if !updates.is_empty() {
659 eprintln!(" Updated: {}", updates.len());
660 }
661 }
662
663 let mut apply_succeeded = false;
665 if !params.save_only && patches_added > 0 {
666 if !params.json && !params.silent {
667 eprintln!("\nApplying patches...");
668 }
669 let apply_args = super::apply::ApplyArgs {
670 common: crate::args::GlobalArgs {
671 cwd: params.cwd.clone(),
672 manifest_path: manifest_path.display().to_string(),
673 global: params.global,
674 global_prefix: params.global_prefix.clone(),
675 silent: params.json || params.silent,
676 download_mode: params.download_mode.clone(),
677 ..crate::args::GlobalArgs::default()
678 },
679 force: false,
680 };
681 let code = super::apply::run(apply_args).await;
682 apply_succeeded = code == 0;
683 if code != 0 && !params.json && !params.silent {
684 eprintln!("\nSome patches could not be applied.");
685 }
686 }
687
688 let result_json = serde_json::json!({
689 "status": if patches_failed > 0 { "partial_failure" } else { "success" },
690 "found": selected.len(),
691 "downloaded": patches_added,
692 "skipped": patches_skipped,
693 "failed": patches_failed,
694 "applied": if apply_succeeded { patches_added } else { 0 },
695 "updated": updates.len(),
696 "patches": downloaded_patches,
697 });
698
699 let exit_code = if patches_failed > 0 || (!apply_succeeded && patches_added > 0 && !params.save_only) { 1 } else { 0 };
700 (exit_code, result_json)
701}
702
703pub async fn run(args: GetArgs) -> i32 {
704 let type_flags = [args.id, args.cve, args.ghsa, args.package]
706 .iter()
707 .filter(|&&f| f)
708 .count();
709 if type_flags > 1 {
710 if args.common.json {
711 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
712 "status": "error",
713 "error": "Only one of --id, --cve, --ghsa, or --package can be specified",
714 })).unwrap());
715 } else {
716 eprintln!("Error: Only one of --id, --cve, --ghsa, or --package can be specified");
717 }
718 return 1;
719 }
720 if args.one_off && args.save_only {
721 if args.common.json {
722 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
723 "status": "error",
724 "error": "--one-off and --save-only cannot be used together",
725 })).unwrap());
726 } else {
727 eprintln!("Error: --one-off and --save-only cannot be used together");
728 }
729 return 1;
730 }
731
732 apply_env_toggles(&args.common);
733 let overrides = args.common.api_client_overrides();
734 let (mut api_client, mut use_public_proxy) =
735 get_api_client_with_overrides(overrides.clone()).await;
736 let telemetry_token = api_client.api_token().cloned();
737 let telemetry_org = api_client.org_slug().cloned();
738 let download_mode = args.common.download_mode.clone();
739 let mut fallback_to_proxy = false;
744
745 let effective_org_slug: Option<&str> = None;
747
748 let id_type = if args.id {
750 IdentifierType::Uuid
751 } else if args.cve {
752 IdentifierType::Cve
753 } else if args.ghsa {
754 IdentifierType::Ghsa
755 } else if args.package {
756 IdentifierType::Package
757 } else {
758 match detect_identifier_type(&args.identifier) {
759 Some(t) => t,
760 None => {
761 if !args.common.json {
762 println!("Treating \"{}\" as a package name search", args.identifier);
763 }
764 IdentifierType::Package
765 }
766 }
767 };
768
769 if id_type == IdentifierType::Uuid {
771 if !args.common.json {
772 println!("Fetching patch by UUID: {}", args.identifier);
773 }
774 let mut fetch_result = api_client
775 .fetch_patch(effective_org_slug, &args.identifier)
776 .await;
777 if !use_public_proxy {
781 if let Err(ref e) = fetch_result {
782 if is_fallback_candidate(e) {
783 eprintln!(
784 "Warning: authenticated API returned {e}; \
785 falling back to public patch API proxy (free patches only)."
786 );
787 api_client = build_proxy_fallback_client(&overrides);
788 use_public_proxy = true;
789 fallback_to_proxy = true;
790 fetch_result = api_client
791 .fetch_patch(effective_org_slug, &args.identifier)
792 .await;
793 }
794 }
795 }
796 match fetch_result {
797 Ok(Some(patch)) => {
798 if patch.tier == "paid" && use_public_proxy {
799 track_patch_fetch_failed(
800 &patch.uuid,
801 "paid_required",
802 fallback_to_proxy,
803 telemetry_token.as_deref(),
804 telemetry_org.as_deref(),
805 )
806 .await;
807 if args.common.json {
808 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
809 "status": "paid_required",
810 "found": 1,
811 "downloaded": 0,
812 "applied": 0,
813 "patches": [{
814 "purl": patch.purl,
815 "uuid": patch.uuid,
816 "tier": "paid",
817 }],
818 })).unwrap());
819 } else {
820 println!("\nThis patch requires a paid subscription to download.");
821 println!("\n Patch: {}", patch.purl);
822 println!(" Tier: paid");
823 println!("\n Upgrade at: https://socket.dev/pricing\n");
824 }
825 return 0;
826 }
827
828 track_patch_fetched(
834 &patch.uuid,
835 &patch.tier,
836 &ecosystem_from_purl(&patch.purl),
837 &download_mode,
838 fallback_to_proxy,
839 telemetry_token.as_deref(),
840 telemetry_org.as_deref(),
841 )
842 .await;
843 return save_and_apply_patch(&args, &patch.purl, &patch.uuid, effective_org_slug)
845 .await;
846 }
847 Ok(None) => {
848 track_patch_fetch_failed(
849 &args.identifier,
850 "not_found",
851 fallback_to_proxy,
852 telemetry_token.as_deref(),
853 telemetry_org.as_deref(),
854 )
855 .await;
856 if args.common.json {
857 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
858 "status": "not_found",
859 "found": 0,
860 "downloaded": 0,
861 "applied": 0,
862 "patches": [],
863 })).unwrap());
864 } else {
865 println!("No patch found with UUID: {}", args.identifier);
866 }
867 return 0;
868 }
869 Err(e) => {
870 track_patch_fetch_failed(
871 &args.identifier,
872 &e,
873 fallback_to_proxy,
874 telemetry_token.as_deref(),
875 telemetry_org.as_deref(),
876 )
877 .await;
878 if args.common.json {
879 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
880 "status": "error",
881 "error": e.to_string(),
882 })).unwrap());
883 } else {
884 eprintln!("Error: {e}");
885 }
886 return 1;
887 }
888 }
889 }
890
891 let search_response: SearchResponse = match id_type {
893 IdentifierType::Cve => {
894 if !args.common.json {
895 println!("Searching patches for CVE: {}", args.identifier);
896 }
897 match api_client
898 .search_patches_by_cve(effective_org_slug, &args.identifier)
899 .await
900 {
901 Ok(r) => r,
902 Err(e) => {
903 track_patch_fetch_failed(
904 &args.identifier,
905 &e,
906 fallback_to_proxy,
907 telemetry_token.as_deref(),
908 telemetry_org.as_deref(),
909 )
910 .await;
911 if args.common.json {
912 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
913 "status": "error",
914 "error": e.to_string(),
915 })).unwrap());
916 } else {
917 eprintln!("Error: {e}");
918 }
919 return 1;
920 }
921 }
922 }
923 IdentifierType::Ghsa => {
924 if !args.common.json {
925 println!("Searching patches for GHSA: {}", args.identifier);
926 }
927 match api_client
928 .search_patches_by_ghsa(effective_org_slug, &args.identifier)
929 .await
930 {
931 Ok(r) => r,
932 Err(e) => {
933 track_patch_fetch_failed(
934 &args.identifier,
935 &e,
936 fallback_to_proxy,
937 telemetry_token.as_deref(),
938 telemetry_org.as_deref(),
939 )
940 .await;
941 if args.common.json {
942 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
943 "status": "error",
944 "error": e.to_string(),
945 })).unwrap());
946 } else {
947 eprintln!("Error: {e}");
948 }
949 return 1;
950 }
951 }
952 }
953 IdentifierType::Purl => {
954 if !args.common.json {
955 println!("Searching patches for PURL: {}", args.identifier);
956 }
957 match api_client
958 .search_patches_by_package(effective_org_slug, &args.identifier)
959 .await
960 {
961 Ok(r) => r,
962 Err(e) => {
963 track_patch_fetch_failed(
964 &args.identifier,
965 &e,
966 fallback_to_proxy,
967 telemetry_token.as_deref(),
968 telemetry_org.as_deref(),
969 )
970 .await;
971 if args.common.json {
972 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
973 "status": "error",
974 "error": e.to_string(),
975 })).unwrap());
976 } else {
977 eprintln!("Error: {e}");
978 }
979 return 1;
980 }
981 }
982 }
983 IdentifierType::Package => {
984 if !args.common.json {
985 println!("Enumerating packages...");
986 }
987 let crawler_options = CrawlerOptions {
988 cwd: args.common.cwd.clone(),
989 global: args.common.global,
990 global_prefix: args.common.global_prefix.clone(),
991 batch_size: 100,
992 };
993 let (all_packages, _) = crawl_all_ecosystems(&crawler_options).await;
994
995 if all_packages.is_empty() {
996 if args.common.json {
997 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
998 "status": "no_packages",
999 "found": 0,
1000 "downloaded": 0,
1001 "applied": 0,
1002 "patches": [],
1003 })).unwrap());
1004 } else if args.common.global {
1005 println!("No global packages found.");
1006 } else {
1007 #[allow(unused_mut)]
1008 let mut install_cmds = String::from("npm/yarn/pnpm/pip");
1009 #[cfg(feature = "cargo")]
1010 install_cmds.push_str("/cargo");
1011 #[cfg(feature = "golang")]
1012 install_cmds.push_str("/go");
1013 #[cfg(feature = "maven")]
1014 install_cmds.push_str("/mvn");
1015 #[cfg(feature = "composer")]
1016 install_cmds.push_str("/composer");
1017 println!("No packages found. Run {install_cmds} install first.");
1018 }
1019 return 0;
1020 }
1021
1022 if !args.common.json {
1023 println!("Found {} packages", all_packages.len());
1024 }
1025
1026 let matches = fuzzy_match_packages(&args.identifier, &all_packages, 20);
1027
1028 if matches.is_empty() {
1029 if args.common.json {
1030 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1031 "status": "no_match",
1032 "found": 0,
1033 "downloaded": 0,
1034 "applied": 0,
1035 "patches": [],
1036 })).unwrap());
1037 } else {
1038 println!("No packages matching \"{}\" found.", args.identifier);
1039 }
1040 return 0;
1041 }
1042
1043 if !args.common.json {
1044 println!(
1045 "Found {} matching package(s), checking for available patches...",
1046 matches.len()
1047 );
1048 }
1049
1050 let best_match = &matches[0];
1052 match api_client
1053 .search_patches_by_package(effective_org_slug, &best_match.purl)
1054 .await
1055 {
1056 Ok(r) => r,
1057 Err(e) => {
1058 track_patch_fetch_failed(
1059 &args.identifier,
1060 &e,
1061 fallback_to_proxy,
1062 telemetry_token.as_deref(),
1063 telemetry_org.as_deref(),
1064 )
1065 .await;
1066 if args.common.json {
1067 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1068 "status": "error",
1069 "error": e.to_string(),
1070 })).unwrap());
1071 } else {
1072 eprintln!("Error: {e}");
1073 }
1074 return 1;
1075 }
1076 }
1077 }
1078 _ => unreachable!(),
1079 };
1080
1081 if search_response.patches.is_empty() {
1082 if args.common.json {
1083 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1084 "status": "not_found",
1085 "found": 0,
1086 "downloaded": 0,
1087 "applied": 0,
1088 "patches": [],
1089 })).unwrap());
1090 } else {
1091 println!(
1092 "No patches found for {}: {}",
1093 id_type, args.identifier
1094 );
1095 }
1096 return 0;
1097 }
1098
1099 if !args.common.json {
1100 display_search_results(&search_response.patches, search_response.can_access_paid_patches);
1101 }
1102
1103 let accessible: Vec<_> = search_response
1105 .patches
1106 .iter()
1107 .filter(|p| p.tier == "free" || search_response.can_access_paid_patches)
1108 .cloned()
1109 .collect();
1110
1111 if accessible.is_empty() {
1112 if args.common.json {
1113 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1114 "status": "paid_required",
1115 "found": search_response.patches.len(),
1116 "downloaded": 0,
1117 "applied": 0,
1118 "patches": search_response.patches.iter().map(|p| serde_json::json!({
1119 "purl": p.purl,
1120 "uuid": p.uuid,
1121 "tier": p.tier,
1122 })).collect::<Vec<_>>(),
1123 })).unwrap());
1124 } else {
1125 println!("\nAll available patches require a paid subscription.");
1126 println!("\n Upgrade at: https://socket.dev/pricing\n");
1127 }
1128 return 0;
1129 }
1130
1131 let selected = match select_patches(
1133 &accessible,
1134 search_response.can_access_paid_patches,
1135 args.common.json,
1136 ) {
1137 Ok(s) => s,
1138 Err(code) => return code,
1139 };
1140
1141 if selected.is_empty() {
1142 if !args.common.json {
1143 println!("No patches selected.");
1144 }
1145 return 0;
1146 }
1147
1148 let prompt = format!("Download {} patch(es)?", selected.len());
1150 if !confirm(&prompt, true, args.common.yes, args.common.json) {
1151 if !args.common.json {
1152 println!("Download cancelled.");
1153 }
1154 return 0;
1155 }
1156
1157 let params = DownloadParams {
1159 cwd: args.common.cwd.clone(),
1160 org: args.common.org.clone(),
1161 save_only: args.save_only,
1162 one_off: args.one_off,
1163 global: args.common.global,
1164 global_prefix: args.common.global_prefix.clone(),
1165 json: args.common.json,
1166 silent: false,
1167 download_mode: args.common.download_mode.clone(),
1168 api_overrides: args.common.api_client_overrides(),
1169 };
1170
1171 let (code, result_json) = download_and_apply_patches(&selected, ¶ms).await;
1172
1173 if args.common.json {
1174 println!("{}", serde_json::to_string_pretty(&result_json).unwrap());
1175 }
1176
1177 code
1178}
1179
1180fn display_search_results(patches: &[PatchSearchResult], can_access_paid: bool) {
1181 println!("\nFound patches:\n");
1182
1183 for (i, patch) in patches.iter().enumerate() {
1184 let tier_label = if patch.tier == "paid" {
1185 " [PAID]"
1186 } else {
1187 " [FREE]"
1188 };
1189 let access_label = if patch.tier == "paid" && !can_access_paid {
1190 " (no access)"
1191 } else {
1192 ""
1193 };
1194
1195 println!(" {}. {}{}{}", i + 1, patch.purl, tier_label, access_label);
1196 println!(" UUID: {}", patch.uuid);
1197 if !patch.description.is_empty() {
1198 let desc = if patch.description.len() > 80 {
1199 format!("{}...", &patch.description[..77])
1200 } else {
1201 patch.description.clone()
1202 };
1203 println!(" Description: {desc}");
1204 }
1205
1206 let vuln_ids: Vec<_> = patch.vulnerabilities.keys().collect();
1207 if !vuln_ids.is_empty() {
1208 let vuln_summary: Vec<String> = patch
1209 .vulnerabilities
1210 .iter()
1211 .map(|(id, vuln)| {
1212 let cves = if vuln.cves.is_empty() {
1213 id.to_string()
1214 } else {
1215 vuln.cves.join(", ")
1216 };
1217 format!("{cves} ({})", vuln.severity)
1218 })
1219 .collect();
1220 println!(" Fixes: {}", vuln_summary.join(", "));
1221 }
1222 println!();
1223 }
1224}
1225
1226async fn save_and_apply_patch(
1227 args: &GetArgs,
1228 _purl: &str,
1229 uuid: &str,
1230 _org_slug: Option<&str>,
1231) -> i32 {
1232 let (api_client, _) =
1234 get_api_client_with_overrides(args.common.api_client_overrides()).await;
1235 let effective_org: Option<&str> = None; let patch = match api_client.fetch_patch(effective_org, uuid).await {
1238 Ok(Some(p)) => p,
1239 Ok(None) => {
1240 if args.common.json {
1241 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1242 "status": "not_found",
1243 "found": 0,
1244 "downloaded": 0,
1245 "applied": 0,
1246 "patches": [],
1247 })).unwrap());
1248 } else {
1249 println!("No patch found with UUID: {uuid}");
1250 }
1251 return 0;
1252 }
1253 Err(e) => {
1254 if args.common.json {
1255 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1256 "status": "error",
1257 "error": e.to_string(),
1258 })).unwrap());
1259 } else {
1260 eprintln!("Error: {e}");
1261 }
1262 return 1;
1263 }
1264 };
1265
1266 let socket_dir = args.common.cwd.join(".socket");
1267 let blobs_dir = socket_dir.join("blobs");
1268 let manifest_path = socket_dir.join("manifest.json");
1269
1270 if let Err(e) = tokio::fs::create_dir_all(&blobs_dir).await {
1271 if args.common.json {
1272 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1273 "status": "error",
1274 "error": format!("Failed to create blobs directory: {}", e),
1275 })).unwrap());
1276 } else {
1277 eprintln!("Error: Failed to create blobs directory: {}", e);
1278 }
1279 return 1;
1280 }
1281
1282 let mut manifest = match read_manifest(&manifest_path).await {
1283 Ok(Some(m)) => m,
1284 _ => PatchManifest::new(),
1285 };
1286
1287 let mut blob_failed = false;
1289 let mut files = HashMap::new();
1290 for (file_path, file_info) in &patch.files {
1291 if let Some(ref after) = file_info.after_hash {
1292 files.insert(
1293 file_path.clone(),
1294 PatchFileInfo {
1295 before_hash: file_info
1296 .before_hash
1297 .clone()
1298 .unwrap_or_default(),
1299 after_hash: after.clone(),
1300 },
1301 );
1302 }
1303 if let (Some(ref blob_content), Some(ref after_hash)) =
1304 (&file_info.blob_content, &file_info.after_hash)
1305 {
1306 match base64_decode(blob_content) {
1307 Ok(decoded) => {
1308 if let Err(e) = tokio::fs::write(blobs_dir.join(after_hash), &decoded).await {
1309 if !args.common.json {
1310 eprintln!(" [error] Failed to write blob for {}: {}", file_path, e);
1311 }
1312 blob_failed = true;
1313 break;
1314 }
1315 }
1316 Err(e) => {
1317 if !args.common.json {
1318 eprintln!(" [error] Failed to decode blob for {}: {}", file_path, e);
1319 }
1320 blob_failed = true;
1321 break;
1322 }
1323 }
1324 }
1325 if let (Some(ref before_blob), Some(ref before_hash)) =
1327 (&file_info.before_blob_content, &file_info.before_hash)
1328 {
1329 match base64_decode(before_blob) {
1330 Ok(decoded) => {
1331 if let Err(e) = tokio::fs::write(blobs_dir.join(before_hash), &decoded).await {
1332 if !args.common.json {
1333 eprintln!(" [error] Failed to write before-blob for {}: {}", file_path, e);
1334 }
1335 blob_failed = true;
1336 break;
1337 }
1338 }
1339 Err(e) => {
1340 if !args.common.json {
1341 eprintln!(" [error] Failed to decode before-blob for {}: {}", file_path, e);
1342 }
1343 blob_failed = true;
1344 break;
1345 }
1346 }
1347 }
1348 }
1349
1350 if blob_failed {
1351 if args.common.json {
1352 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1353 "status": "error",
1354 "found": 1,
1355 "downloaded": 0,
1356 "applied": 0,
1357 "error": "Blob decode or write failed",
1358 "patches": [{
1359 "purl": patch.purl,
1360 "uuid": patch.uuid,
1361 "action": "failed",
1362 "error": "Blob decode or write failed",
1363 }],
1364 })).unwrap());
1365 } else {
1366 eprintln!("Error: Blob decode or write failed for patch {}", patch.purl);
1367 }
1368 return 1;
1369 }
1370
1371 let vulnerabilities: HashMap<String, VulnerabilityInfo> = patch
1372 .vulnerabilities
1373 .iter()
1374 .map(|(id, v)| {
1375 (
1376 id.clone(),
1377 VulnerabilityInfo {
1378 cves: v.cves.clone(),
1379 summary: v.summary.clone(),
1380 severity: v.severity.clone(),
1381 description: v.description.clone(),
1382 },
1383 )
1384 })
1385 .collect();
1386
1387 let added = manifest
1388 .patches
1389 .get(&patch.purl)
1390 .is_none_or(|p| p.uuid != patch.uuid);
1391
1392 manifest.patches.insert(
1393 patch.purl.clone(),
1394 PatchRecord {
1395 uuid: patch.uuid.clone(),
1396 exported_at: patch.published_at.clone(),
1397 files,
1398 vulnerabilities,
1399 description: patch.description.clone(),
1400 license: patch.license.clone(),
1401 tier: patch.tier.clone(),
1402 },
1403 );
1404
1405 if let Err(e) = write_manifest(&manifest_path, &manifest).await {
1406 if args.common.json {
1407 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1408 "status": "error",
1409 "error": format!("Error writing manifest: {e}"),
1410 })).unwrap());
1411 } else {
1412 eprintln!("Error writing manifest: {e}");
1413 }
1414 return 1;
1415 }
1416
1417 if !args.common.json {
1418 println!("\nPatch saved to {}", manifest_path.display());
1419 if added {
1420 println!(" Added: 1");
1421 } else {
1422 println!(" Skipped: 1 (already exists)");
1423 }
1424 }
1425
1426 let mut apply_succeeded = false;
1427 if !args.save_only && added {
1428 if !args.common.json {
1429 println!("\nApplying patches...");
1430 }
1431 let apply_args = super::apply::ApplyArgs {
1432 common: crate::args::GlobalArgs {
1433 cwd: args.common.cwd.clone(),
1434 manifest_path: manifest_path.display().to_string(),
1435 global: args.common.global,
1436 global_prefix: args.common.global_prefix.clone(),
1437 silent: args.common.json,
1438 download_mode: args.common.download_mode.clone(),
1439 ..crate::args::GlobalArgs::default()
1440 },
1441 force: false,
1442 };
1443 let code = super::apply::run(apply_args).await;
1444 apply_succeeded = code == 0;
1445 if code != 0 && !args.common.json {
1446 eprintln!("\nSome patches could not be applied.");
1447 }
1448 }
1449
1450 if args.common.json {
1451 let mut patch_record = serde_json::json!({
1452 "purl": patch.purl,
1453 "uuid": patch.uuid,
1454 "action": if added { "added" } else { "skipped" },
1455 });
1456 if added {
1457 merge_metadata(&mut patch_record, patch_event_metadata(&patch));
1460 }
1461 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1462 "status": "success",
1463 "found": 1,
1464 "downloaded": if added { 1 } else { 0 },
1465 "applied": if apply_succeeded { 1 } else { 0 },
1466 "patches": [patch_record],
1467 })).unwrap());
1468 }
1469
1470 if !apply_succeeded && added && !args.save_only { 1 } else { 0 }
1471}
1472
/// Decode standard (RFC 4648 §4) base64 text into bytes.
///
/// The decoder is lenient in two ways:
/// - `=` padding is optional: every `=` byte is ignored.
/// - `\n` and `\r` are skipped, so line-wrapped input is accepted.
///
/// Any other byte outside the `A-Z a-z 0-9 + /` alphabet is rejected.
///
/// # Errors
///
/// Returns an error message when:
/// - the input contains a byte outside the alphabet; or
/// - the number of data characters is one more than a multiple of four. A lone
///   trailing character carries only 6 bits, which no valid encoding produces, so
///   the input is truncated or corrupt.
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
    /// Marker for bytes that are not part of the base64 alphabet.
    const INVALID: u8 = 255;
    /// Byte -> 6-bit value lookup table, built once at compile time.
    const TABLE: [u8; 256] = {
        let alphabet = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut table = [INVALID; 256];
        let mut i = 0;
        while i < alphabet.len() {
            table[alphabet[i] as usize] = i as u8;
            i += 1;
        }
        table
    };

    let input = input.as_bytes();
    let mut output = Vec::with_capacity(input.len() * 3 / 4);

    // `buf` accumulates sextets; `bits` counts how many of its low bits are pending.
    let mut buf = 0u32;
    let mut bits = 0u32;
    let mut data_chars = 0usize;

    for &b in input {
        if matches!(b, b'=' | b'\n' | b'\r') {
            continue;
        }
        let val = TABLE[b as usize];
        if val == INVALID {
            return Err(format!("Invalid base64 character: {}", b as char));
        }
        data_chars += 1;
        buf = (buf << 6) | u32::from(val);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            // `buf` never exceeds 12 bits here, so the shifted value fits in a byte.
            output.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }

    if data_chars % 4 == 1 {
        return Err(format!(
            "Invalid base64 length: {data_chars} data characters is one more than a multiple of 4"
        ));
    }

    Ok(output)
}
1505
#[cfg(test)]
mod tests {
    use super::*;
    use socket_patch_core::api::types::VulnerabilityResponse;
    use std::collections::HashMap;

    // ---- detect_identifier_type ----

    #[test]
    fn detect_uuid_lowercase() {
        assert_eq!(
            detect_identifier_type("80630680-4da6-45f9-bba8-b888e0ffd58c"),
            Some(IdentifierType::Uuid)
        );
    }

    #[test]
    fn detect_uuid_uppercase() {
        assert_eq!(
            detect_identifier_type("80630680-4DA6-45F9-BBA8-B888E0FFD58C"),
            Some(IdentifierType::Uuid)
        );
    }

    #[test]
    fn detect_cve_uppercase() {
        assert_eq!(
            detect_identifier_type("CVE-2021-44906"),
            Some(IdentifierType::Cve)
        );
    }

    #[test]
    fn detect_cve_lowercase() {
        assert_eq!(
            detect_identifier_type("cve-2021-44906"),
            Some(IdentifierType::Cve)
        );
    }

    #[test]
    fn detect_ghsa_uppercase() {
        assert_eq!(
            detect_identifier_type("GHSA-abcd-1234-wxyz"),
            Some(IdentifierType::Ghsa)
        );
    }

    #[test]
    fn detect_ghsa_lowercase() {
        assert_eq!(
            detect_identifier_type("ghsa-abcd-1234-wxyz"),
            Some(IdentifierType::Ghsa)
        );
    }

    #[test]
    fn detect_purl() {
        assert_eq!(
            detect_identifier_type("pkg:npm/foo@1.0"),
            Some(IdentifierType::Purl)
        );
    }

    #[test]
    fn detect_package_name_returns_none() {
        assert_eq!(detect_identifier_type("minimist"), None);
    }

    #[test]
    fn detect_malformed_cve_returns_none() {
        assert_eq!(detect_identifier_type("CVE-not-a-year"), None);
    }

    #[test]
    fn detect_empty_string_returns_none() {
        assert_eq!(detect_identifier_type(""), None);
    }

    // ---- select_patches ----

    /// Build a minimal search result with no vulnerabilities.
    fn mk_patch(
        uuid: &str,
        purl: &str,
        tier: &str,
        published_at: &str,
    ) -> PatchSearchResult {
        PatchSearchResult {
            uuid: uuid.into(),
            purl: purl.into(),
            published_at: published_at.into(),
            description: format!("desc-{uuid}"),
            license: "MIT".into(),
            tier: tier.into(),
            vulnerabilities: HashMap::<String, VulnerabilityResponse>::new(),
        }
    }

    #[test]
    fn select_free_user_one_free_patch_returns_it() {
        let patches = vec![mk_patch("u1", "pkg:npm/foo@1.0", "free", "2024-01-01")];
        let out = select_patches(&patches, false, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "u1");
    }

    #[test]
    fn select_paid_user_prefers_paid_over_free_same_purl() {
        let patches = vec![
            mk_patch("free1", "pkg:npm/foo@1.0", "free", "2024-06-01"),
            mk_patch("paid1", "pkg:npm/foo@1.0", "paid", "2024-01-01"),
        ];
        let out = select_patches(&patches, true, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "paid1");
        assert_eq!(out[0].tier, "paid");
    }

    #[test]
    fn select_paid_user_picks_most_recent_paid() {
        let patches = vec![
            mk_patch("old", "pkg:npm/foo@1.0", "paid", "2024-01-01"),
            mk_patch("new", "pkg:npm/foo@1.0", "paid", "2024-06-01"),
        ];
        let out = select_patches(&patches, true, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "new");
    }

    #[test]
    fn select_paid_user_falls_back_to_most_recent_free_when_no_paid() {
        let patches = vec![
            mk_patch("old", "pkg:npm/foo@1.0", "free", "2024-01-01"),
            mk_patch("new", "pkg:npm/foo@1.0", "free", "2024-06-01"),
        ];
        let out = select_patches(&patches, true, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "new");
    }

    #[test]
    fn select_free_user_multi_free_json_mode_errors() {
        let patches = vec![
            mk_patch("a", "pkg:npm/foo@1.0", "free", "2024-01-01"),
            mk_patch("b", "pkg:npm/foo@1.0", "free", "2024-06-01"),
        ];
        let err = select_patches(&patches, false, true).expect_err("should fail");
        assert_eq!(err, 1);
    }

    #[test]
    fn select_empty_input_returns_empty() {
        let out = select_patches(&[], false, false).expect("ok");
        assert!(out.is_empty());
        let out = select_patches(&[], true, false).expect("ok");
        assert!(out.is_empty());
        let out = select_patches(&[], false, true).expect("ok");
        assert!(out.is_empty());
    }

    #[test]
    fn select_free_user_paid_filtered_out_then_single_free_auto_selects() {
        let patches = vec![
            mk_patch("paid", "pkg:npm/foo@1.0", "paid", "2024-06-01"),
            mk_patch("free", "pkg:npm/foo@1.0", "free", "2024-01-01"),
        ];
        let out = select_patches(&patches, false, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "free");
        assert_eq!(out[0].tier, "free");
    }

    // ---- decide_patch_action ----

    /// Build a manifest holding a single (purl -> uuid) record.
    fn manifest_with_entry(purl: &str, uuid: &str) -> PatchManifest {
        let mut m = PatchManifest::new();
        m.patches.insert(
            purl.to_string(),
            PatchRecord {
                uuid: uuid.to_string(),
                exported_at: String::new(),
                files: HashMap::new(),
                vulnerabilities: HashMap::new(),
                description: String::new(),
                license: String::new(),
                tier: "free".to_string(),
            },
        );
        m
    }

    #[test]
    fn decide_patch_action_added_when_purl_absent() {
        let manifest = PatchManifest::new();
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/foo@1.0", "uuid-a"),
            PatchAction::Added,
        );
    }

    #[test]
    fn decide_patch_action_skipped_when_same_uuid() {
        let manifest = manifest_with_entry("pkg:npm/foo@1.0", "uuid-a");
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/foo@1.0", "uuid-a"),
            PatchAction::Skipped,
        );
    }

    #[test]
    fn decide_patch_action_updated_when_different_uuid() {
        let manifest = manifest_with_entry("pkg:npm/foo@1.0", "uuid-a");
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/foo@1.0", "uuid-b"),
            PatchAction::Updated {
                old_uuid: "uuid-a".to_string()
            },
        );
    }

    #[test]
    fn decide_patch_action_added_for_different_purl_even_with_overlapping_manifest() {
        let manifest = manifest_with_entry("pkg:npm/foo@1.0", "uuid-a");
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/bar@2.0", "uuid-a"),
            PatchAction::Added,
        );
    }

    // ---- severity helpers ----

    #[test]
    fn severity_rank_orders_canonical_labels() {
        assert!(severity_rank("critical") > severity_rank("high"));
        assert!(severity_rank("high") > severity_rank("medium"));
        assert!(severity_rank("medium") > severity_rank("low"));
        assert_eq!(severity_rank("moderate"), severity_rank("medium"));
        assert!(severity_rank("low") > severity_rank(""));
        assert!(severity_rank("low") > severity_rank("unknown"));
    }

    #[test]
    fn max_vuln_severity_picks_highest() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-low".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-low".into()],
                summary: String::new(),
                severity: "low".into(),
                description: String::new(),
            },
        );
        vulns.insert(
            "GHSA-crit".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-crit".into()],
                summary: String::new(),
                severity: "critical".into(),
                description: String::new(),
            },
        );
        vulns.insert(
            "GHSA-mod".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-mod".into()],
                summary: String::new(),
                severity: "moderate".into(),
                description: String::new(),
            },
        );
        assert_eq!(max_vuln_severity(&vulns).as_deref(), Some("critical"));
    }

    #[test]
    fn max_vuln_severity_returns_none_for_empty() {
        assert_eq!(max_vuln_severity(&HashMap::new()), None);
    }

    // ---- patch_event_metadata ----

    #[test]
    fn patch_event_metadata_includes_all_keys() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-aaaa-bbbb-cccc".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-2024-12345".into()],
                summary: "Prototype Pollution".into(),
                severity: "high".into(),
                description: "merge() does not check Object.prototype".into(),
            },
        );
        let patch = PatchResponse {
            uuid: "11111111-1111-4111-8111-111111111111".into(),
            purl: "pkg:npm/minimist@1.2.2".into(),
            published_at: "2024-01-01T00:00:00Z".into(),
            files: HashMap::new(),
            vulnerabilities: vulns,
            description: "Fixes prototype pollution in minimist".into(),
            license: "MIT".into(),
            tier: "free".into(),
        };
        let meta = patch_event_metadata(&patch);
        assert_eq!(meta["description"], "Fixes prototype pollution in minimist");
        assert_eq!(meta["license"], "MIT");
        assert_eq!(meta["tier"], "free");
        assert_eq!(meta["exportedAt"], "2024-01-01T00:00:00Z");
        assert_eq!(meta["severity"], "high");
        let vulns_out = meta["vulnerabilities"].as_array().unwrap();
        assert_eq!(vulns_out.len(), 1);
        assert_eq!(vulns_out[0]["id"], "GHSA-aaaa-bbbb-cccc");
        assert_eq!(vulns_out[0]["cves"][0], "CVE-2024-12345");
        assert_eq!(vulns_out[0]["severity"], "high");
        assert_eq!(vulns_out[0]["summary"], "Prototype Pollution");
    }

    #[test]
    fn patch_event_metadata_sorts_vulnerabilities_by_id() {
        let mut vulns = HashMap::new();
        for id in ["GHSA-zzz", "GHSA-aaa", "GHSA-mmm"] {
            vulns.insert(
                id.into(),
                VulnerabilityResponse {
                    cves: Vec::new(),
                    summary: String::new(),
                    severity: "low".into(),
                    description: String::new(),
                },
            );
        }
        let patch = PatchResponse {
            uuid: String::new(),
            purl: String::new(),
            published_at: String::new(),
            files: HashMap::new(),
            vulnerabilities: vulns,
            description: String::new(),
            license: String::new(),
            tier: String::new(),
        };
        let meta = patch_event_metadata(&patch);
        let ids: Vec<&str> = meta["vulnerabilities"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v["id"].as_str().unwrap())
            .collect();
        assert_eq!(ids, ["GHSA-aaa", "GHSA-mmm", "GHSA-zzz"]);
    }

    #[test]
    fn patch_event_metadata_omits_severity_when_no_vulns() {
        let patch = PatchResponse {
            uuid: String::new(),
            purl: String::new(),
            published_at: "ts".into(),
            files: HashMap::new(),
            vulnerabilities: HashMap::new(),
            description: "desc".into(),
            license: "MIT".into(),
            tier: "free".into(),
        };
        let meta = patch_event_metadata(&patch);
        assert!(meta.as_object().unwrap().get("severity").is_none());
        assert_eq!(meta["vulnerabilities"].as_array().unwrap().len(), 0);
    }

    // ---- base64_decode ----

    #[test]
    fn base64_decode_handles_padded_and_unpadded_input() {
        assert_eq!(base64_decode("aGVsbG8=").unwrap(), b"hello");
        assert_eq!(base64_decode("aGVsbG8").unwrap(), b"hello");
        assert_eq!(base64_decode("QUJD").unwrap(), b"ABC");
    }

    #[test]
    fn base64_decode_skips_line_breaks() {
        assert_eq!(base64_decode("aGVs\r\nbG8=\n").unwrap(), b"hello");
    }

    #[test]
    fn base64_decode_covers_alphabet_edges() {
        assert_eq!(base64_decode("AAECAwQF").unwrap(), [0u8, 1, 2, 3, 4, 5]);
        assert_eq!(base64_decode("/+8=").unwrap(), [0xffu8, 0xef]);
    }

    #[test]
    fn base64_decode_empty_input_is_empty() {
        assert!(base64_decode("").unwrap().is_empty());
    }

    #[test]
    fn base64_decode_rejects_invalid_characters() {
        assert!(base64_decode("aGVs*bG8=").is_err());
        assert!(base64_decode("aGVs bG8=").is_err());
    }
}