1use clap::Args;
2use regex::Regex;
3use socket_patch_core::api::client::{
4 build_proxy_fallback_client, get_api_client_with_overrides, is_fallback_candidate,
5};
6use socket_patch_core::api::types::{
7 PatchResponse, PatchSearchResult, SearchResponse, VulnerabilityResponse,
8};
9use socket_patch_core::crawlers::{CrawlerOptions, Ecosystem};
10use socket_patch_core::manifest::operations::{read_manifest, write_manifest};
11use socket_patch_core::manifest::schema::{
12 PatchFileInfo, PatchManifest, PatchRecord, VulnerabilityInfo,
13};
14use socket_patch_core::patch::apply::select_installed_variants;
15use socket_patch_core::utils::fuzzy_match::fuzzy_match_packages;
16use socket_patch_core::utils::purl::{is_purl, strip_purl_qualifiers};
17use socket_patch_core::utils::telemetry::{track_patch_fetch_failed, track_patch_fetched};
18use std::collections::HashMap;
19use std::fmt;
20use std::path::{Path, PathBuf};
21
22use crate::args::{apply_env_toggles, GlobalArgs};
23use crate::ecosystem_dispatch::{crawl_all_ecosystems, find_packages_for_rollback, partition_purls};
24use crate::output::{confirm, select_one, SelectError};
25
/// Extract the package-type segment of a PURL (`"npm"` for
/// `"pkg:npm/lodash@4.17.21"`).
///
/// Returns an empty string when `purl` does not start with `pkg:`. When there
/// is no `/`, everything after the `pkg:` prefix is returned.
fn ecosystem_from_purl(purl: &str) -> String {
    let Some(rest) = purl.strip_prefix("pkg:") else {
        return String::new();
    };
    let kind = rest.split_once('/').map_or(rest, |(head, _)| head);
    kind.to_owned()
}
36
/// Outcome of comparing a freshly fetched patch against the manifest entry
/// recorded for the same PURL (see `decide_patch_action`).
#[derive(Debug, PartialEq, Eq, Clone)]
pub(crate) enum PatchAction {
    /// The manifest has no entry for this PURL yet.
    Added,
    /// The manifest has an entry for this PURL with a different UUID.
    Updated {
        /// UUID of the manifest entry that is being replaced.
        old_uuid: String,
    },
    /// The manifest already records exactly this UUID for the PURL.
    Skipped,
}
50
51pub(crate) fn decide_patch_action(
54 manifest: &PatchManifest,
55 purl: &str,
56 new_uuid: &str,
57) -> PatchAction {
58 match manifest.patches.get(purl) {
59 Some(existing) if existing.uuid == new_uuid => PatchAction::Skipped,
60 Some(existing) => PatchAction::Updated {
61 old_uuid: existing.uuid.clone(),
62 },
63 None => PatchAction::Added,
64 }
65}
66
/// Map a severity label to a sortable rank, case-insensitively (ASCII).
///
/// critical = 4, high = 3, moderate/medium = 2, low = 1, anything else = 0.
pub(crate) fn severity_rank(severity: &str) -> u8 {
    const RANKS: [(&str, u8); 5] = [
        ("critical", 4),
        ("high", 3),
        ("moderate", 2),
        ("medium", 2),
        ("low", 1),
    ];
    RANKS
        .iter()
        .find(|(label, _)| severity.eq_ignore_ascii_case(label))
        .map_or(0, |&(_, rank)| rank)
}
80
81pub(crate) fn max_vuln_severity(
85 vulns: &HashMap<String, VulnerabilityResponse>,
86) -> Option<String> {
87 vulns
88 .values()
89 .max_by_key(|v| severity_rank(&v.severity))
90 .map(|v| v.severity.clone())
91}
92
93pub(crate) fn patch_event_metadata(patch: &PatchResponse) -> serde_json::Value {
104 let mut vulns: Vec<serde_json::Value> = patch
105 .vulnerabilities
106 .iter()
107 .map(|(id, v)| {
108 serde_json::json!({
109 "id": id,
110 "cves": v.cves,
111 "severity": v.severity,
112 "summary": v.summary,
113 "description": v.description,
114 })
115 })
116 .collect();
117 vulns.sort_by(|a, b| {
120 a["id"]
121 .as_str()
122 .unwrap_or("")
123 .cmp(b["id"].as_str().unwrap_or(""))
124 });
125
126 let mut meta = serde_json::Map::new();
127 meta.insert(
128 "description".into(),
129 serde_json::Value::String(patch.description.clone()),
130 );
131 meta.insert(
132 "license".into(),
133 serde_json::Value::String(patch.license.clone()),
134 );
135 meta.insert(
136 "tier".into(),
137 serde_json::Value::String(patch.tier.clone()),
138 );
139 meta.insert(
140 "exportedAt".into(),
141 serde_json::Value::String(patch.published_at.clone()),
142 );
143 if let Some(sev) = max_vuln_severity(&patch.vulnerabilities) {
144 meta.insert("severity".into(), serde_json::Value::String(sev));
145 }
146 meta.insert("vulnerabilities".into(), serde_json::Value::Array(vulns));
147 serde_json::Value::Object(meta)
148}
149
150fn merge_metadata(record: &mut serde_json::Value, meta: serde_json::Value) {
154 if let (Some(record_obj), serde_json::Value::Object(meta_obj)) =
155 (record.as_object_mut(), meta)
156 {
157 for (k, v) in meta_obj {
158 record_obj.insert(k, v);
159 }
160 }
161}
162
163fn print_json(v: &serde_json::Value) {
165 println!("{}", serde_json::to_string_pretty(v).unwrap());
166}
167
/// Shorten `s` to at most `limit` characters (Unicode scalar values).
///
/// Strings that already fit are returned unchanged. Longer strings are cut
/// and end in `"..."`, with the ellipsis counting toward `limit`. When
/// `limit` is too small to hold the ellipsis (`< 3`), the string is hard-cut
/// to `limit` characters with no ellipsis. The previous behaviour returned
/// `"..."` in that case, which exceeds `limit`.
pub(crate) fn truncate_with_ellipsis(s: &str, limit: usize) -> String {
    const ELLIPSIS: &str = "...";
    // ELLIPSIS is ASCII, so its byte length equals its char count.
    let ellipsis_chars = ELLIPSIS.len();

    if s.chars().count() <= limit {
        return s.to_string();
    }
    if limit < ellipsis_chars {
        return s.chars().take(limit).collect();
    }
    let head: String = s.chars().take(limit - ellipsis_chars).collect();
    format!("{head}{ELLIPSIS}")
}
182
/// Return the first 8 bytes of `uuid` for compact display.
///
/// The whole string is returned when it is shorter than 8 bytes, or when
/// byte 8 is not a UTF-8 character boundary.
fn short_uuid(uuid: &str) -> &str {
    const PREFIX_BYTES: usize = 8;
    if uuid.is_char_boundary(PREFIX_BYTES) {
        &uuid[..PREFIX_BYTES]
    } else {
        uuid
    }
}
191
192fn empty_result_json(status: &str) -> serde_json::Value {
196 serde_json::json!({
197 "status": status,
198 "found": 0,
199 "downloaded": 0,
200 "applied": 0,
201 "patches": [],
202 })
203}
204
205async fn report_fetch_failure(
209 identifier: &str,
210 error: impl std::fmt::Display,
211 fallback_to_proxy: bool,
212 api_token: Option<&str>,
213 org_slug: Option<&str>,
214 json: bool,
215) -> i32 {
216 let msg = error.to_string();
217 track_patch_fetch_failed(identifier, &msg, fallback_to_proxy, api_token, org_slug).await;
218 report_error(json, msg);
219 1
220}
221
222fn report_error(json: bool, message: impl std::fmt::Display) {
225 let message = message.to_string();
226 if json {
227 print_json(&serde_json::json!({"status": "error", "error": message}));
228 } else {
229 eprintln!("Error: {message}");
230 }
231}
232
233async fn write_blob_entry(
236 blobs_dir: &Path,
237 b64: &str,
238 hash: &str,
239 file_path: &str,
240 label: &str,
241) -> Result<(), String> {
242 let decoded = base64_decode(b64)
243 .map_err(|e| format!("Failed to decode {label} for {file_path}: {e}"))?;
244 tokio::fs::write(blobs_dir.join(hash), &decoded)
245 .await
246 .map_err(|e| format!("Failed to write {label} for {file_path}: {e}"))
247}
248
249async fn write_all_patch_blobs(
253 blobs_dir: &Path,
254 patch: &PatchResponse,
255 quiet: bool,
256) -> Result<(), ()> {
257 for (file_path, file_info) in &patch.files {
258 if let (Some(blob), Some(hash)) =
259 (&file_info.blob_content, &file_info.after_hash)
260 {
261 if let Err(e) = write_blob_entry(blobs_dir, blob, hash, file_path, "blob").await {
262 if !quiet {
263 eprintln!(" [error] {e}");
264 }
265 return Err(());
266 }
267 }
268 if let (Some(blob), Some(hash)) =
269 (&file_info.before_blob_content, &file_info.before_hash)
270 {
271 if let Err(e) =
272 write_blob_entry(blobs_dir, blob, hash, file_path, "before-blob").await
273 {
274 if !quiet {
275 eprintln!(" [error] {e}");
276 }
277 return Err(());
278 }
279 }
280 }
281 Ok(())
282}
283
284fn vulnerabilities_for_manifest(
287 vulns: &HashMap<String, VulnerabilityResponse>,
288) -> HashMap<String, VulnerabilityInfo> {
289 vulns
290 .iter()
291 .map(|(id, v)| {
292 (
293 id.clone(),
294 VulnerabilityInfo {
295 cves: v.cves.clone(),
296 summary: v.summary.clone(),
297 severity: v.severity.clone(),
298 description: v.description.clone(),
299 },
300 )
301 })
302 .collect()
303}
304
305fn build_patch_record(
310 patch: &PatchResponse,
311 files: HashMap<String, PatchFileInfo>,
312) -> PatchRecord {
313 PatchRecord {
314 uuid: patch.uuid.clone(),
315 exported_at: patch.published_at.clone(),
316 files,
317 vulnerabilities: vulnerabilities_for_manifest(&patch.vulnerabilities),
318 description: patch.description.clone(),
319 license: patch.license.clone(),
320 tier: patch.tier.clone(),
321 }
322}
323
// Command-line arguments for the `get` subcommand.
// NOTE: plain `//` comments are used deliberately; `///` doc comments on clap
// derive fields become `--help` text and would change the CLI output.
#[derive(Args)]
pub struct GetArgs {
    // Positional (no #[arg] attribute): a patch UUID, CVE id, GHSA id, PURL, or
    // package name. `run` picks the interpretation from the type flags below,
    // falling back to pattern detection.
    pub identifier: String,

    // Shared options (cwd, json, global, org, download mode, ...).
    #[command(flatten)]
    pub common: GlobalArgs,

    // Force `identifier` to be treated as a patch UUID.
    #[arg(long, default_value_t = false)]
    pub id: bool,

    // Force `identifier` to be treated as a CVE id.
    #[arg(long, default_value_t = false)]
    pub cve: bool,

    // Force `identifier` to be treated as a GHSA id.
    #[arg(long, default_value_t = false)]
    pub ghsa: bool,

    // Force a fuzzy package-name search. At most one of --id/--cve/--ghsa/
    // --package may be set (checked in `run`).
    #[arg(short = 'p', long = "package", default_value_t = false)]
    pub package: bool,

    // Save patches to the manifest without applying them (alias --no-apply,
    // env SOCKET_SAVE_ONLY). Rejected together with --one-off in `run`.
    #[arg(long = "save-only", alias = "no-apply", env = "SOCKET_SAVE_ONLY", default_value_t = false)]
    pub save_only: bool,

    // --one-off (env SOCKET_ONE_OFF); mutually exclusive with --save-only.
    #[arg(long = "one-off", env = "SOCKET_ONE_OFF", default_value_t = false)]
    pub one_off: bool,

    // Keep every release variant of a package instead of narrowing to the
    // variant matching the installed distribution. The boolish parser
    // presumably lets the env var accept values like `1`/`yes` — see clap's
    // BoolishValueParser.
    #[arg(
        long = "all-releases",
        env = "SOCKET_ALL_RELEASES",
        default_value_t = false,
        value_parser = clap::builder::BoolishValueParser::new(),
    )]
    pub all_releases: bool,
}
370
/// How the positional `identifier` of `get` is interpreted.
#[derive(Debug, Clone, Copy, PartialEq)]
enum IdentifierType {
    /// A patch UUID, fetched directly.
    Uuid,
    /// A CVE identifier, e.g. `CVE-2024-1234`.
    Cve,
    /// A GitHub security advisory id, e.g. `GHSA-xxxx-xxxx-xxxx`.
    Ghsa,
    /// A package URL (`pkg:...`).
    Purl,
    /// A free-text package name, fuzzy-matched against installed packages.
    Package,
}

impl fmt::Display for IdentifierType {
    /// Writes the human-readable label used in user messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Uuid => "UUID",
            Self::Cve => "CVE",
            Self::Ghsa => "GHSA",
            Self::Purl => "PURL",
            Self::Package => "package name",
        };
        f.write_str(label)
    }
}
391
392fn detect_identifier_type(identifier: &str) -> Option<IdentifierType> {
393 let uuid_re = Regex::new(r"(?i)^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$").unwrap();
394 let cve_re = Regex::new(r"(?i)^CVE-\d{4}-\d+$").unwrap();
395 let ghsa_re = Regex::new(r"(?i)^GHSA-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}$").unwrap();
396
397 if uuid_re.is_match(identifier) {
398 Some(IdentifierType::Uuid)
399 } else if cve_re.is_match(identifier) {
400 Some(IdentifierType::Cve)
401 } else if ghsa_re.is_match(identifier) {
402 Some(IdentifierType::Ghsa)
403 } else if is_purl(identifier) {
404 Some(IdentifierType::Purl)
405 } else {
406 None
407 }
408}
409
410pub fn select_patches(
419 patches: &[PatchSearchResult],
420 can_access_paid: bool,
421 is_json: bool,
422) -> Result<Vec<PatchSearchResult>, i32> {
423 let mut by_purl: HashMap<String, Vec<&PatchSearchResult>> = HashMap::new();
425 for p in patches {
426 if p.tier == "free" || can_access_paid {
427 by_purl.entry(p.purl.clone()).or_default().push(p);
428 }
429 }
430
431 let mut selected = Vec::new();
432
433 for (purl, mut group) in by_purl {
434 group.sort_by(|a, b| b.published_at.cmp(&a.published_at));
436
437 if can_access_paid {
438 let choice = group
440 .iter()
441 .find(|p| p.tier == "paid")
442 .or_else(|| group.first())
443 .unwrap();
444 selected.push((*choice).clone());
445 } else if group.len() == 1 {
446 selected.push(group[0].clone());
447 } else {
448 let options: Vec<String> = group
450 .iter()
451 .map(|p| {
452 let vuln_summary: Vec<String> = p
453 .vulnerabilities
454 .iter()
455 .map(|(id, v)| {
456 if v.cves.is_empty() {
457 id.clone()
458 } else {
459 v.cves.join(", ")
460 }
461 })
462 .collect();
463 let vulns = if vuln_summary.is_empty() {
464 String::new()
465 } else {
466 format!(" (fixes: {})", vuln_summary.join(", "))
467 };
468 let desc = truncate_with_ellipsis(&p.description, 60);
469 format!("{} [{}]{} - {}", p.uuid, p.tier, vulns, desc)
470 })
471 .collect();
472
473 match select_one(
474 &format!("Multiple patches available for {purl}. Select one:"),
475 &options,
476 is_json,
477 ) {
478 Ok(idx) => {
479 selected.push(group[idx].clone());
480 }
481 Err(SelectError::JsonModeNeedsExplicit) => {
482 let options_json: Vec<serde_json::Value> = group
483 .iter()
484 .map(|p| {
485 let vulns: Vec<serde_json::Value> = p
486 .vulnerabilities
487 .iter()
488 .map(|(id, v)| {
489 serde_json::json!({
490 "id": id,
491 "cves": v.cves,
492 "severity": v.severity,
493 "summary": v.summary,
494 })
495 })
496 .collect();
497 serde_json::json!({
498 "uuid": p.uuid,
499 "tier": p.tier,
500 "published_at": p.published_at,
501 "description": p.description,
502 "vulnerabilities": vulns,
503 })
504 })
505 .collect();
506 println!(
507 "{}",
508 serde_json::to_string_pretty(&serde_json::json!({
509 "status": "selection_required",
510 "error": format!("Multiple patches available for {purl}. Specify --id <UUID> to select one."),
511 "purl": purl,
512 "options": options_json,
513 }))
514 .unwrap()
515 );
516 return Err(1);
517 }
518 Err(SelectError::Cancelled) => {
519 eprintln!("Selection cancelled.");
520 return Err(0);
521 }
522 }
523 }
524 }
525
526 Ok(selected)
527}
528
529pub struct DownloadParams {
531 pub cwd: PathBuf,
532 pub org: Option<String>,
533 pub save_only: bool,
534 pub one_off: bool,
535 pub global: bool,
536 pub global_prefix: Option<PathBuf>,
537 pub json: bool,
538 pub silent: bool,
539 pub download_mode: String,
541 pub api_overrides: socket_patch_core::api::client::ApiClientEnvOverrides,
546 pub all_releases: bool,
552}
553
/// Narrow release-variant patches to the variant(s) matching the locally
/// installed distribution.
///
/// `download_and_apply_patches` calls this unless `params.all_releases` is set.
///
/// - Patches whose ecosystem does not report `supports_release_variants()` are
///   kept unchanged. So are variant groups with a single member; groups are
///   keyed by PURL with qualifiers stripped.
/// - For groups with several variants, the installed package is located via
///   `find_packages_for_rollback`. The whole group is kept, and a warning
///   recorded, in two cases: the package is not installed, or
///   `select_installed_variants` matches none of the variants.
/// - Otherwise only the variants at the indices returned by
///   `select_installed_variants` are kept.
///
/// Returns `(kept, warnings)`.
/// NOTE(review): the order of `kept` is not stable, because groups are
/// drained from a `HashMap`.
async fn filter_to_installed_releases(
    selected: &[PatchSearchResult],
    params: &DownloadParams,
    api_client: &socket_patch_core::api::client::ApiClient,
    org: Option<&str>,
) -> (Vec<PatchSearchResult>, Vec<String>) {
    // Split into release-variant-capable patches (grouped by base PURL) and
    // everything else (kept as-is).
    let mut variant_groups: HashMap<String, Vec<PatchSearchResult>> = HashMap::new();
    let mut kept: Vec<PatchSearchResult> = Vec::new();
    for sr in selected {
        if Ecosystem::from_purl(&sr.purl).is_some_and(|e| e.supports_release_variants()) {
            variant_groups
                .entry(strip_purl_qualifiers(&sr.purl).to_string())
                .or_default()
                .push(sr.clone());
        } else {
            kept.push(sr.clone());
        }
    }

    let mut warnings: Vec<String> = Vec::new();

    // Only groups with more than one variant need disambiguation.
    let mut multi: Vec<(String, Vec<PatchSearchResult>)> = Vec::new();
    for (base, variants) in variant_groups {
        if variants.len() <= 1 {
            kept.extend(variants);
        } else {
            multi.push((base, variants));
        }
    }

    if multi.is_empty() {
        return (kept, warnings);
    }

    // Locate the installed packages for all qualified variant PURLs in one crawl.
    let all_qualified: Vec<String> = multi
        .iter()
        .flat_map(|(_, variants)| variants.iter().map(|s| s.purl.clone()))
        .collect();
    let partitioned = partition_purls(&all_qualified, None);
    let crawler_options = CrawlerOptions {
        cwd: params.cwd.clone(),
        global: params.global,
        global_prefix: params.global_prefix.clone(),
        batch_size: 100,
    };
    // NOTE(review): the meaning of the `true` argument is defined in
    // ecosystem_dispatch — confirm.
    let paths = find_packages_for_rollback(&partitioned, &crawler_options, true).await;

    for (base, variants) in multi {
        // The first variant whose qualified PURL resolved to a path gives the
        // package location.
        let pkg_path = variants.iter().find_map(|s| paths.get(&s.purl)).cloned();
        let Some(pkg_path) = pkg_path else {
            warnings.push(format!(
                "{base} is not installed locally; keeping all {} release variant(s).",
                variants.len()
            ));
            kept.extend(variants);
            continue;
        };

        // Fetch each variant's file hashes for matching. A fetch error or a
        // missing patch yields an empty file map.
        // NOTE(review): how select_installed_variants treats an empty map is
        // not visible here. These patches are fetched again later by
        // download_and_apply_patches.
        let mut candidates: Vec<(String, HashMap<String, PatchFileInfo>)> = Vec::new();
        for s in &variants {
            match api_client.fetch_patch(org, &s.uuid).await {
                Ok(Some(patch)) => {
                    candidates.push((s.purl.clone(), files_for_selection(&patch)));
                }
                _ => candidates.push((s.purl.clone(), HashMap::new())),
            }
        }

        let refs: Vec<(&str, &HashMap<String, PatchFileInfo>)> = candidates
            .iter()
            .map(|(purl, files)| (purl.as_str(), files))
            .collect();

        // `matched` holds indices into `candidates` (same order as `refs`).
        let matched = select_installed_variants(&pkg_path, &refs).await;
        if matched.is_empty() {
            warnings.push(format!(
                "No release variant of {base} matches the installed distribution; keeping all {} variant(s).",
                variants.len()
            ));
            kept.extend(variants);
        } else {
            let winners: std::collections::HashSet<String> =
                matched.iter().map(|&i| candidates[i].0.clone()).collect();
            kept.extend(variants.into_iter().filter(|s| winners.contains(&s.purl)));
        }
    }

    (kept, warnings)
}
691
692fn files_for_selection(patch: &PatchResponse) -> HashMap<String, PatchFileInfo> {
697 let mut files = HashMap::new();
698 for (file_path, file_info) in &patch.files {
699 if let (Some(before), Some(after)) = (&file_info.before_hash, &file_info.after_hash) {
700 files.insert(
701 file_path.clone(),
702 PatchFileInfo {
703 before_hash: before.clone(),
704 after_hash: after.clone(),
705 },
706 );
707 }
708 }
709 files
710}
711
712pub async fn download_and_apply_patches(
716 selected: &[PatchSearchResult],
717 params: &DownloadParams,
718) -> (i32, serde_json::Value) {
719 let mut overrides = params.api_overrides.clone();
720 if overrides.org_slug.is_none() {
721 overrides.org_slug = params.org.clone();
722 }
723 let (api_client, _) =
724 socket_patch_core::api::client::get_api_client_with_overrides(overrides).await;
725 let effective_org: Option<&str> = None;
726
727 let socket_dir = params.cwd.join(".socket");
728 let blobs_dir = socket_dir.join("blobs");
729 let manifest_path = socket_dir.join("manifest.json");
730
731 if let Err(e) = tokio::fs::create_dir_all(&socket_dir).await {
732 let err = format!("Failed to create .socket directory: {}", e);
733 report_error(params.json, &err);
734 return (1, serde_json::json!({"status": "error", "error": err}));
735 }
736 if let Err(e) = tokio::fs::create_dir_all(&blobs_dir).await {
737 let err = format!("Failed to create blobs directory: {}", e);
738 report_error(params.json, &err);
739 return (1, serde_json::json!({"status": "error", "error": err}));
740 }
741
742 let mut manifest = match read_manifest(&manifest_path).await {
743 Ok(Some(m)) => m,
744 _ => PatchManifest::new(),
745 };
746
747 let mut narrow_warnings: Vec<String> = Vec::new();
751 let selected_owned: Vec<PatchSearchResult>;
752 let selected: &[PatchSearchResult] = if params.all_releases {
753 selected
754 } else {
755 let (kept, warns) =
756 filter_to_installed_releases(selected, params, &api_client, effective_org).await;
757 if !params.json && !params.silent {
758 for w in &warns {
759 eprintln!(" [note] {w}");
760 }
761 }
762 narrow_warnings = warns;
763 selected_owned = kept;
764 &selected_owned
765 };
766
767 if !params.json && !params.silent {
768 eprintln!("\nDownloading {} patch(es)...", selected.len());
769 }
770
771 let mut patches_added = 0;
772 let mut patches_skipped = 0;
773 let mut patches_failed = 0;
774 let mut downloaded_patches: Vec<serde_json::Value> = Vec::new();
775 let mut updates: Vec<String> = Vec::new();
776
777 for search_result in selected {
778 if let Some(existing) = manifest.patches.get(&search_result.purl) {
780 if existing.uuid != search_result.uuid {
781 updates.push(search_result.purl.clone());
782 if !params.json && !params.silent {
783 eprintln!(
784 " [update] {} (replacing {})",
785 search_result.purl,
786 short_uuid(&existing.uuid)
790 );
791 }
792 }
793 }
794
795 match api_client
796 .fetch_patch(effective_org, &search_result.uuid)
797 .await
798 {
799 Ok(Some(patch)) => {
800 let action = decide_patch_action(&manifest, &patch.purl, &patch.uuid);
804 if let PatchAction::Skipped = action {
805 if !params.json && !params.silent {
806 eprintln!(" [skip] {} (already in manifest)", patch.purl);
807 }
808 downloaded_patches.push(serde_json::json!({
809 "purl": patch.purl,
810 "uuid": patch.uuid,
811 "action": "skipped",
812 }));
813 patches_skipped += 1;
814 continue;
815 }
816
817 let mut files = HashMap::new();
821 for (file_path, file_info) in &patch.files {
822 if let (Some(before), Some(after)) =
823 (&file_info.before_hash, &file_info.after_hash)
824 {
825 files.insert(
826 file_path.clone(),
827 PatchFileInfo {
828 before_hash: before.clone(),
829 after_hash: after.clone(),
830 },
831 );
832 }
833 }
834
835 let quiet = params.json || params.silent;
836 if write_all_patch_blobs(&blobs_dir, &patch, quiet).await.is_err() {
837 patches_failed += 1;
838 downloaded_patches.push(serde_json::json!({
839 "purl": patch.purl,
840 "uuid": patch.uuid,
841 "action": "failed",
842 "error": "Blob decode or write failed",
843 }));
844 continue;
845 }
846
847 manifest
848 .patches
849 .insert(patch.purl.clone(), build_patch_record(&patch, files));
850
851 let mut action_record = match &action {
852 PatchAction::Updated { old_uuid } => {
853 if !params.json && !params.silent {
854 eprintln!(" [update] {}", patch.purl);
855 }
856 serde_json::json!({
857 "purl": patch.purl,
858 "uuid": patch.uuid,
859 "action": "updated",
860 "oldUuid": old_uuid,
861 })
862 }
863 _ => {
864 if !params.json && !params.silent {
865 eprintln!(" [add] {}", patch.purl);
866 }
867 serde_json::json!({
868 "purl": patch.purl,
869 "uuid": patch.uuid,
870 "action": "added",
871 })
872 }
873 };
874 merge_metadata(&mut action_record, patch_event_metadata(&patch));
879 downloaded_patches.push(action_record);
880 patches_added += 1;
881 }
882 Ok(None) => {
883 if !params.json && !params.silent {
884 eprintln!(" [fail] {} (could not fetch details)", search_result.purl);
885 }
886 downloaded_patches.push(serde_json::json!({
887 "purl": search_result.purl,
888 "uuid": search_result.uuid,
889 "action": "failed",
890 "error": "could not fetch details",
891 }));
892 patches_failed += 1;
893 }
894 Err(e) => {
895 if !params.json && !params.silent {
896 eprintln!(" [fail] {} ({e})", search_result.purl);
897 }
898 downloaded_patches.push(serde_json::json!({
899 "purl": search_result.purl,
900 "uuid": search_result.uuid,
901 "action": "failed",
902 "error": e.to_string(),
903 }));
904 patches_failed += 1;
905 }
906 }
907 }
908
909 if let Err(e) = write_manifest(&manifest_path, &manifest).await {
911 let msg = format!("Error writing manifest: {e}");
912 let err_json = serde_json::json!({ "status": "error", "error": &msg });
913 if params.json {
914 print_json(&err_json);
915 } else {
916 eprintln!("{msg}");
917 }
918 return (1, err_json);
919 }
920
921 if !params.json && !params.silent {
922 eprintln!("\nPatches saved to {}", manifest_path.display());
923 eprintln!(" Added: {patches_added}");
924 if patches_skipped > 0 {
925 eprintln!(" Skipped: {patches_skipped}");
926 }
927 if patches_failed > 0 {
928 eprintln!(" Failed: {patches_failed}");
929 }
930 if !updates.is_empty() {
931 eprintln!(" Updated: {}", updates.len());
932 }
933 }
934
935 let mut apply_succeeded = false;
937 if !params.save_only && patches_added > 0 {
938 if !params.json && !params.silent {
939 eprintln!("\nApplying patches...");
940 }
941 let apply_args = super::apply::ApplyArgs {
942 common: crate::args::GlobalArgs {
943 cwd: params.cwd.clone(),
944 manifest_path: manifest_path.display().to_string(),
945 global: params.global,
946 global_prefix: params.global_prefix.clone(),
947 silent: params.json || params.silent,
948 download_mode: params.download_mode.clone(),
949 ..crate::args::GlobalArgs::default()
950 },
951 force: false,
952 vex: Default::default(),
955 };
956 let code = super::apply::run(apply_args).await;
957 apply_succeeded = code == 0;
958 if code != 0 && !params.json && !params.silent {
959 eprintln!("\nSome patches could not be applied.");
960 }
961 }
962
963 let mut result_json = serde_json::json!({
964 "status": if patches_failed > 0 { "partial_failure" } else { "success" },
965 "found": selected.len(),
966 "downloaded": patches_added,
967 "skipped": patches_skipped,
968 "failed": patches_failed,
969 "applied": if apply_succeeded { patches_added } else { 0 },
970 "updated": updates.len(),
971 "patches": downloaded_patches,
972 });
973 if !narrow_warnings.is_empty() {
977 result_json["warnings"] = serde_json::json!(narrow_warnings);
978 }
979
980 let exit_code = if patches_failed > 0 || (!apply_succeeded && patches_added > 0 && !params.save_only) { 1 } else { 0 };
981 (exit_code, result_json)
982}
983
/// Entry point for the `get` subcommand: resolve `args.identifier` to one or
/// more patches, let the user pick and confirm, then download (and by default
/// apply) them.
///
/// Flow:
/// 1. Reject conflicting flags: more than one of `--id/--cve/--ghsa/--package`,
///    or `--one-off` together with `--save-only`.
/// 2. Decide the identifier type from the explicit flag, else by pattern
///    (`detect_identifier_type`), else treat it as a package-name search.
/// 3. UUID: fetch that single patch and hand off to `save_and_apply_patch`.
///    On errors accepted by `is_fallback_candidate`, the fetch is retried once
///    through the public proxy.
/// 4. CVE / GHSA / PURL / package name: search, then filter to accessible
///    tiers. Next, select one patch per PURL and confirm. Finally call
///    `download_and_apply_patches`.
///
/// Returns the process exit code. "Nothing found", "paid required", and user
/// cancellation all return 0.
pub async fn run(args: GetArgs) -> i32 {
    // At most one explicit identifier-type flag may be given.
    let type_flags = [args.id, args.cve, args.ghsa, args.package]
        .iter()
        .filter(|&&f| f)
        .count();
    if type_flags > 1 {
        report_error(
            args.common.json,
            "Only one of --id, --cve, --ghsa, or --package can be specified",
        );
        return 1;
    }
    if args.one_off && args.save_only {
        if args.common.json {
            print_json(&serde_json::json!({
                "status": "error",
                "error": "--one-off and --save-only cannot be used together",
            }));
        } else {
            eprintln!("Error: --one-off and --save-only cannot be used together");
        }
        return 1;
    }

    apply_env_toggles(&args.common);
    let overrides = args.common.api_client_overrides();
    let (mut api_client, mut use_public_proxy) =
        get_api_client_with_overrides(overrides.clone()).await;
    // Captured before any proxy fallback below, so telemetry keeps reporting
    // the originally configured token/org.
    let telemetry_token = api_client.api_token().cloned();
    let telemetry_org = api_client.org_slug().cloned();
    let download_mode = args.common.download_mode.clone();
    // Set when the UUID path below switches to the proxy client; forwarded to
    // telemetry.
    let mut fallback_to_proxy = false;

    // NOTE(review): always None in this function; presumably the org is
    // carried by the API client's overrides instead — confirm.
    let effective_org_slug: Option<&str> = None;

    // Explicit flag wins. Otherwise detect by pattern, else treat the input
    // as a package-name search.
    let id_type = if args.id {
        IdentifierType::Uuid
    } else if args.cve {
        IdentifierType::Cve
    } else if args.ghsa {
        IdentifierType::Ghsa
    } else if args.package {
        IdentifierType::Package
    } else {
        match detect_identifier_type(&args.identifier) {
            Some(t) => t,
            None => {
                if !args.common.json {
                    println!("Treating \"{}\" as a package name search", args.identifier);
                }
                IdentifierType::Package
            }
        }
    };

    // UUID: fetch one patch directly. Every outcome returns from this branch.
    if id_type == IdentifierType::Uuid {
        if !args.common.json {
            println!("Fetching patch by UUID: {}", args.identifier);
        }
        let mut fetch_result = api_client
            .fetch_patch(effective_org_slug, &args.identifier)
            .await;
        // Retry once through the public proxy client when the authenticated
        // API failed with an error that `is_fallback_candidate` accepts.
        if !use_public_proxy {
            if let Err(ref e) = fetch_result {
                if is_fallback_candidate(e) {
                    eprintln!(
                        "Warning: authenticated API returned {e}; \
                         falling back to public patch API proxy (free patches only)."
                    );
                    api_client = build_proxy_fallback_client(&overrides);
                    use_public_proxy = true;
                    fallback_to_proxy = true;
                    fetch_result = api_client
                        .fetch_patch(effective_org_slug, &args.identifier)
                        .await;
                }
            }
        }
        match fetch_result {
            Ok(Some(patch)) => {
                // A paid patch seen through the public proxy is not
                // downloaded: record it, tell the user, and exit 0.
                if patch.tier == "paid" && use_public_proxy {
                    track_patch_fetch_failed(
                        &patch.uuid,
                        "paid_required",
                        fallback_to_proxy,
                        telemetry_token.as_deref(),
                        telemetry_org.as_deref(),
                    )
                    .await;
                    if args.common.json {
                        print_json(&serde_json::json!({
                            "status": "paid_required",
                            "found": 1,
                            "downloaded": 0,
                            "applied": 0,
                            "patches": [{
                                "purl": patch.purl,
                                "uuid": patch.uuid,
                                "tier": "paid",
                            }],
                        }));
                    } else {
                        println!("\nThis patch requires a paid subscription to download.");
                        println!("\n Patch: {}", patch.purl);
                        println!(" Tier: paid");
                        println!("\n Upgrade at: https://socket.dev/pricing\n");
                    }
                    return 0;
                }

                track_patch_fetched(
                    &patch.uuid,
                    &patch.tier,
                    &ecosystem_from_purl(&patch.purl),
                    &download_mode,
                    fallback_to_proxy,
                    telemetry_token.as_deref(),
                    telemetry_org.as_deref(),
                )
                .await;
                // NOTE(review): save_and_apply_patch appears to build a fresh
                // client from args.common.api_client_overrides() and fetch the
                // patch again. After a proxy fallback, that second fetch would
                // not use the proxy client above — confirm.
                return save_and_apply_patch(&args, &patch.purl, &patch.uuid, effective_org_slug)
                    .await;
            }
            Ok(None) => {
                track_patch_fetch_failed(
                    &args.identifier,
                    "not_found",
                    fallback_to_proxy,
                    telemetry_token.as_deref(),
                    telemetry_org.as_deref(),
                )
                .await;
                if args.common.json {
                    print_json(&empty_result_json("not_found"));
                } else {
                    println!("No patch found with UUID: {}", args.identifier);
                }
                return 0;
            }
            Err(e) => {
                return report_fetch_failure(
                    &args.identifier,
                    e,
                    fallback_to_proxy,
                    telemetry_token.as_deref(),
                    telemetry_org.as_deref(),
                    args.common.json,
                )
                .await;
            }
        }
    }

    // Non-UUID identifiers: search for candidate patches. Unlike the UUID path,
    // no proxy fallback is attempted here.
    let search_response: SearchResponse = match id_type {
        IdentifierType::Cve | IdentifierType::Ghsa | IdentifierType::Purl => {
            if !args.common.json {
                let label = match id_type {
                    IdentifierType::Cve => "CVE",
                    IdentifierType::Ghsa => "GHSA",
                    IdentifierType::Purl => "PURL",
                    _ => unreachable!(),
                };
                println!("Searching patches for {label}: {}", args.identifier);
            }
            let result = match id_type {
                IdentifierType::Cve => {
                    api_client
                        .search_patches_by_cve(effective_org_slug, &args.identifier)
                        .await
                }
                IdentifierType::Ghsa => {
                    api_client
                        .search_patches_by_ghsa(effective_org_slug, &args.identifier)
                        .await
                }
                IdentifierType::Purl => {
                    api_client
                        .search_patches_by_package(effective_org_slug, &args.identifier)
                        .await
                }
                _ => unreachable!(),
            };
            match result {
                Ok(r) => r,
                Err(e) => {
                    return report_fetch_failure(
                        &args.identifier,
                        e,
                        fallback_to_proxy,
                        telemetry_token.as_deref(),
                        telemetry_org.as_deref(),
                        args.common.json,
                    )
                    .await;
                }
            }
        }
        IdentifierType::Package => {
            // Enumerate installed packages and fuzzy-match the name against them.
            if !args.common.json {
                println!("Enumerating packages...");
            }
            let crawler_options = CrawlerOptions {
                cwd: args.common.cwd.clone(),
                global: args.common.global,
                global_prefix: args.common.global_prefix.clone(),
                batch_size: 100,
            };
            let (all_packages, _) = crawl_all_ecosystems(&crawler_options).await;

            if all_packages.is_empty() {
                if args.common.json {
                    print_json(&empty_result_json("no_packages"));
                } else if args.common.global {
                    println!("No global packages found.");
                } else {
                    // `mut` is only exercised when an optional ecosystem
                    // feature below is compiled in.
                    #[allow(unused_mut)]
                    let mut install_cmds = String::from("npm/yarn/pnpm/pip");
                    #[cfg(feature = "cargo")]
                    install_cmds.push_str("/cargo");
                    #[cfg(feature = "golang")]
                    install_cmds.push_str("/go");
                    #[cfg(feature = "maven")]
                    install_cmds.push_str("/mvn");
                    #[cfg(feature = "composer")]
                    install_cmds.push_str("/composer");
                    println!("No packages found. Run {install_cmds} install first.");
                }
                return 0;
            }

            if !args.common.json {
                println!("Found {} packages", all_packages.len());
            }

            // Up to 20 fuzzy matches are computed.
            let matches = fuzzy_match_packages(&args.identifier, &all_packages, 20);

            if matches.is_empty() {
                if args.common.json {
                    print_json(&empty_result_json("no_match"));
                } else {
                    println!("No packages matching \"{}\" found.", args.identifier);
                }
                return 0;
            }

            if !args.common.json {
                println!(
                    "Found {} matching package(s), checking for available patches...",
                    matches.len()
                );
            }

            // Only the best match (matches[0]) is searched; the rest are only counted.
            let best_match = &matches[0];
            match api_client
                .search_patches_by_package(effective_org_slug, &best_match.purl)
                .await
            {
                Ok(r) => r,
                Err(e) => {
                    return report_fetch_failure(
                        &args.identifier,
                        e,
                        fallback_to_proxy,
                        telemetry_token.as_deref(),
                        telemetry_org.as_deref(),
                        args.common.json,
                    )
                    .await;
                }
            }
        }
        // Uuid returned from the branch above.
        _ => unreachable!(),
    };

    if search_response.patches.is_empty() {
        if args.common.json {
            print_json(&empty_result_json("not_found"));
        } else {
            println!(
                "No patches found for {}: {}",
                id_type, args.identifier
            );
        }
        return 0;
    }

    if !args.common.json {
        display_search_results(&search_response.patches, search_response.can_access_paid_patches);
    }

    // Same tier filter that select_patches applies internally.
    let accessible: Vec<_> = search_response
        .patches
        .iter()
        .filter(|p| p.tier == "free" || search_response.can_access_paid_patches)
        .cloned()
        .collect();

    if accessible.is_empty() {
        if args.common.json {
            print_json(&serde_json::json!({
                "status": "paid_required",
                "found": search_response.patches.len(),
                "downloaded": 0,
                "applied": 0,
                "patches": search_response.patches.iter().map(|p| serde_json::json!({
                    "purl": p.purl,
                    "uuid": p.uuid,
                    "tier": p.tier,
                })).collect::<Vec<_>>(),
            }));
        } else {
            println!("\nAll available patches require a paid subscription.");
            println!("\n Upgrade at: https://socket.dev/pricing\n");
        }
        return 0;
    }

    // Err carries the exit code: 1 for JSON selection_required, 0 for cancel.
    let selected = match select_patches(
        &accessible,
        search_response.can_access_paid_patches,
        args.common.json,
    ) {
        Ok(s) => s,
        Err(code) => return code,
    };

    if selected.is_empty() {
        if !args.common.json {
            println!("No patches selected.");
        }
        return 0;
    }

    let prompt = format!("Download {} patch(es)?", selected.len());
    if !confirm(&prompt, true, args.common.yes, args.common.json) {
        if !args.common.json {
            println!("Download cancelled.");
        }
        return 0;
    }

    let params = DownloadParams {
        cwd: args.common.cwd.clone(),
        org: args.common.org.clone(),
        save_only: args.save_only,
        one_off: args.one_off,
        global: args.common.global,
        global_prefix: args.common.global_prefix.clone(),
        json: args.common.json,
        silent: false,
        download_mode: args.common.download_mode.clone(),
        api_overrides: args.common.api_client_overrides(),
        all_releases: args.all_releases,
    };

    let (code, result_json) = download_and_apply_patches(&selected, &params).await;

    // JSON output is printed here, once. NOTE(review): confirm that
    // download_and_apply_patches does not also print result_json in JSON mode;
    // otherwise stdout gets two documents.
    if args.common.json {
        println!("{}", serde_json::to_string_pretty(&result_json).unwrap());
    }

    code
}
1372
1373fn display_search_results(patches: &[PatchSearchResult], can_access_paid: bool) {
1374 println!("\nFound patches:\n");
1375
1376 for (i, patch) in patches.iter().enumerate() {
1377 let tier_label = if patch.tier == "paid" {
1378 " [PAID]"
1379 } else {
1380 " [FREE]"
1381 };
1382 let access_label = if patch.tier == "paid" && !can_access_paid {
1383 " (no access)"
1384 } else {
1385 ""
1386 };
1387
1388 println!(" {}. {}{}{}", i + 1, patch.purl, tier_label, access_label);
1389 println!(" UUID: {}", patch.uuid);
1390 if !patch.description.is_empty() {
1391 let desc = truncate_with_ellipsis(&patch.description, 80);
1392 println!(" Description: {desc}");
1393 }
1394
1395 let vuln_ids: Vec<_> = patch.vulnerabilities.keys().collect();
1396 if !vuln_ids.is_empty() {
1397 let vuln_summary: Vec<String> = patch
1398 .vulnerabilities
1399 .iter()
1400 .map(|(id, vuln)| {
1401 let cves = if vuln.cves.is_empty() {
1402 id.to_string()
1403 } else {
1404 vuln.cves.join(", ")
1405 };
1406 format!("{cves} ({})", vuln.severity)
1407 })
1408 .collect();
1409 println!(" Fixes: {}", vuln_summary.join(", "));
1410 }
1411 println!();
1412 }
1413}
1414
1415async fn save_and_apply_patch(
1416 args: &GetArgs,
1417 _purl: &str,
1418 uuid: &str,
1419 _org_slug: Option<&str>,
1420) -> i32 {
1421 let (api_client, _) =
1423 get_api_client_with_overrides(args.common.api_client_overrides()).await;
1424 let effective_org: Option<&str> = None; let patch = match api_client.fetch_patch(effective_org, uuid).await {
1427 Ok(Some(p)) => p,
1428 Ok(None) => {
1429 if args.common.json {
1430 print_json(&empty_result_json("not_found"));
1431 } else {
1432 println!("No patch found with UUID: {uuid}");
1433 }
1434 return 0;
1435 }
1436 Err(e) => {
1437 report_error(args.common.json, e);
1438 return 1;
1439 }
1440 };
1441
1442 let socket_dir = args.common.cwd.join(".socket");
1443 let blobs_dir = socket_dir.join("blobs");
1444 let manifest_path = socket_dir.join("manifest.json");
1445
1446 if let Err(e) = tokio::fs::create_dir_all(&blobs_dir).await {
1447 report_error(args.common.json, format!("Failed to create blobs directory: {e}"));
1448 return 1;
1449 }
1450
1451 let mut manifest = match read_manifest(&manifest_path).await {
1452 Ok(Some(m)) => m,
1453 _ => PatchManifest::new(),
1454 };
1455
1456 let mut files = HashMap::new();
1461 for (file_path, file_info) in &patch.files {
1462 if let Some(after) = &file_info.after_hash {
1463 files.insert(
1464 file_path.clone(),
1465 PatchFileInfo {
1466 before_hash: file_info.before_hash.clone().unwrap_or_default(),
1467 after_hash: after.clone(),
1468 },
1469 );
1470 }
1471 }
1472
1473 if write_all_patch_blobs(&blobs_dir, &patch, args.common.json)
1474 .await
1475 .is_err()
1476 {
1477 if args.common.json {
1478 print_json(&serde_json::json!({
1479 "status": "error",
1480 "found": 1,
1481 "downloaded": 0,
1482 "applied": 0,
1483 "error": "Blob decode or write failed",
1484 "patches": [{
1485 "purl": patch.purl,
1486 "uuid": patch.uuid,
1487 "action": "failed",
1488 "error": "Blob decode or write failed",
1489 }],
1490 }));
1491 } else {
1492 eprintln!("Error: Blob decode or write failed for patch {}", patch.purl);
1493 }
1494 return 1;
1495 }
1496
1497 let added = manifest
1498 .patches
1499 .get(&patch.purl)
1500 .is_none_or(|p| p.uuid != patch.uuid);
1501
1502 manifest
1503 .patches
1504 .insert(patch.purl.clone(), build_patch_record(&patch, files));
1505
1506 if let Err(e) = write_manifest(&manifest_path, &manifest).await {
1507 report_error(args.common.json, format!("Error writing manifest: {e}"));
1508 return 1;
1509 }
1510
1511 if !args.common.json {
1512 println!("\nPatch saved to {}", manifest_path.display());
1513 if added {
1514 println!(" Added: 1");
1515 } else {
1516 println!(" Skipped: 1 (already exists)");
1517 }
1518 }
1519
1520 let mut apply_succeeded = false;
1521 if !args.save_only && added {
1522 if !args.common.json {
1523 println!("\nApplying patches...");
1524 }
1525 let apply_args = super::apply::ApplyArgs {
1526 common: crate::args::GlobalArgs {
1527 cwd: args.common.cwd.clone(),
1528 manifest_path: manifest_path.display().to_string(),
1529 global: args.common.global,
1530 global_prefix: args.common.global_prefix.clone(),
1531 silent: args.common.json,
1532 download_mode: args.common.download_mode.clone(),
1533 ..crate::args::GlobalArgs::default()
1534 },
1535 force: false,
1536 vex: Default::default(),
1539 };
1540 let code = super::apply::run(apply_args).await;
1541 apply_succeeded = code == 0;
1542 if code != 0 && !args.common.json {
1543 eprintln!("\nSome patches could not be applied.");
1544 }
1545 }
1546
1547 if args.common.json {
1548 let mut patch_record = serde_json::json!({
1549 "purl": patch.purl,
1550 "uuid": patch.uuid,
1551 "action": if added { "added" } else { "skipped" },
1552 });
1553 if added {
1554 merge_metadata(&mut patch_record, patch_event_metadata(&patch));
1557 }
1558 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1559 "status": "success",
1560 "found": 1,
1561 "downloaded": if added { 1 } else { 0 },
1562 "applied": if apply_succeeded { 1 } else { 0 },
1563 "patches": [patch_record],
1564 })).unwrap());
1565 }
1566
1567 if !apply_succeeded && added && !args.save_only { 1 } else { 0 }
1568}
1569
/// Decode standard base64 (RFC 4648 §4, `+`/`/` alphabet) into raw bytes.
///
/// Line breaks (`\n`, `\r`) anywhere in the input are ignored, so wrapped
/// encodings decode. Trailing `=` padding is optional.
///
/// # Errors
/// Returns a human-readable message when:
/// - a byte outside the base64 alphabet (other than `=`, `\n`, `\r`) appears;
/// - encoded data follows `=` padding. Skipping the padding would silently
///   splice two encodings together and yield corrupt bytes;
/// - the data ends with a single leftover character, which carries only 6
///   bits and cannot form a byte (truncated input).
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    const INVALID: u8 = u8::MAX;

    // Reverse lookup: input byte -> 6-bit value, or INVALID.
    let mut table = [INVALID; 256];
    for (i, &c) in ALPHABET.iter().enumerate() {
        // `i < 64`, so the cast cannot truncate.
        table[usize::from(c)] = i as u8;
    }

    let input = input.as_bytes();
    let mut output = Vec::with_capacity(input.len() * 3 / 4);

    // Bit accumulator: the low `bits` bits of `buf` are not yet emitted.
    let mut buf = 0u32;
    let mut bits = 0u32;
    let mut data_chars = 0usize;
    let mut seen_padding = false;

    for &b in input {
        match b {
            b'\n' | b'\r' => continue,
            b'=' => {
                seen_padding = true;
                continue;
            }
            _ => {}
        }
        let val = table[usize::from(b)];
        if val == INVALID {
            return Err(format!("Invalid base64 character: {}", b as char));
        }
        if seen_padding {
            return Err("Invalid base64: data after '=' padding".to_string());
        }
        data_chars += 1;
        buf = (buf << 6) | u32::from(val);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            // `buf` holds at most `bits + 8` significant bits, so this fits a u8.
            output.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }

    // A 4-char group encodes 3 bytes; a final group of 2 or 3 chars encodes
    // 1 or 2 bytes. A lone trailing char cannot encode a whole byte.
    if data_chars % 4 == 1 {
        return Err("Invalid base64: truncated input".to_string());
    }

    Ok(output)
}
1602
// Unit tests for identifier detection, patch selection, manifest action
// decisions, severity ranking, telemetry metadata, and string helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use socket_patch_core::api::types::VulnerabilityResponse;
    use std::collections::HashMap;

    // ---- detect_identifier_type -------------------------------------------
    // UUID, CVE, and GHSA identifiers are recognized case-insensitively;
    // strings starting with `pkg:` are PURLs; anything else yields `None`.

    #[test]
    fn detect_uuid_lowercase() {
        assert_eq!(
            detect_identifier_type("80630680-4da6-45f9-bba8-b888e0ffd58c"),
            Some(IdentifierType::Uuid)
        );
    }

    #[test]
    fn detect_uuid_uppercase() {
        assert_eq!(
            detect_identifier_type("80630680-4DA6-45F9-BBA8-B888E0FFD58C"),
            Some(IdentifierType::Uuid)
        );
    }

    #[test]
    fn detect_cve_uppercase() {
        assert_eq!(
            detect_identifier_type("CVE-2021-44906"),
            Some(IdentifierType::Cve)
        );
    }

    #[test]
    fn detect_cve_lowercase() {
        assert_eq!(
            detect_identifier_type("cve-2021-44906"),
            Some(IdentifierType::Cve)
        );
    }

    #[test]
    fn detect_ghsa_uppercase() {
        assert_eq!(
            detect_identifier_type("GHSA-abcd-1234-wxyz"),
            Some(IdentifierType::Ghsa)
        );
    }

    #[test]
    fn detect_ghsa_lowercase() {
        assert_eq!(
            detect_identifier_type("ghsa-abcd-1234-wxyz"),
            Some(IdentifierType::Ghsa)
        );
    }

    #[test]
    fn detect_purl() {
        assert_eq!(
            detect_identifier_type("pkg:npm/foo@1.0"),
            Some(IdentifierType::Purl)
        );
    }

    // A bare package name is not one of the recognized identifier kinds.
    #[test]
    fn detect_package_name_returns_none() {
        assert_eq!(detect_identifier_type("minimist"), None);
    }

    // A `CVE-` prefix alone is not enough; the year/number shape must match.
    #[test]
    fn detect_malformed_cve_returns_none() {
        assert_eq!(detect_identifier_type("CVE-not-a-year"), None);
    }

    #[test]
    fn detect_empty_string_returns_none() {
        assert_eq!(detect_identifier_type(""), None);
    }

    // ---- select_patches ---------------------------------------------------
    // Arguments are (candidates, can_access_paid, json). `Ok` carries the
    // selected patches; `Err` carries a process exit code.

    /// Build a search result with no vulnerabilities and a `desc-<uuid>`
    /// description.
    fn mk_patch(
        uuid: &str,
        purl: &str,
        tier: &str,
        published_at: &str,
    ) -> PatchSearchResult {
        PatchSearchResult {
            uuid: uuid.into(),
            purl: purl.into(),
            published_at: published_at.into(),
            description: format!("desc-{uuid}"),
            license: "MIT".into(),
            tier: tier.into(),
            vulnerabilities: HashMap::<String, VulnerabilityResponse>::new(),
        }
    }

    #[test]
    fn select_free_user_one_free_patch_returns_it() {
        let patches = vec![mk_patch("u1", "pkg:npm/foo@1.0", "free", "2024-01-01")];
        let out = select_patches(&patches, false, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "u1");
    }

    // With paid access, a paid patch wins over a newer free one for the same PURL.
    #[test]
    fn select_paid_user_prefers_paid_over_free_same_purl() {
        let patches = vec![
            mk_patch("free1", "pkg:npm/foo@1.0", "free", "2024-06-01"),
            mk_patch("paid1", "pkg:npm/foo@1.0", "paid", "2024-01-01"),
        ];
        let out = select_patches(&patches, true, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "paid1");
        assert_eq!(out[0].tier, "paid");
    }

    // Among paid patches for one PURL, the latest `published_at` is chosen.
    #[test]
    fn select_paid_user_picks_most_recent_paid() {
        let patches = vec![
            mk_patch("old", "pkg:npm/foo@1.0", "paid", "2024-01-01"),
            mk_patch("new", "pkg:npm/foo@1.0", "paid", "2024-06-01"),
        ];
        let out = select_patches(&patches, true, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "new");
    }

    #[test]
    fn select_paid_user_falls_back_to_most_recent_free_when_no_paid() {
        let patches = vec![
            mk_patch("old", "pkg:npm/foo@1.0", "free", "2024-01-01"),
            mk_patch("new", "pkg:npm/foo@1.0", "free", "2024-06-01"),
        ];
        let out = select_patches(&patches, true, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "new");
    }

    // Without paid access, several free candidates in JSON mode yield exit
    // code 1 (presumably because JSON mode cannot prompt for a choice).
    #[test]
    fn select_free_user_multi_free_json_mode_errors() {
        let patches = vec![
            mk_patch("a", "pkg:npm/foo@1.0", "free", "2024-01-01"),
            mk_patch("b", "pkg:npm/foo@1.0", "free", "2024-06-01"),
        ];
        let err = select_patches(&patches, false, true).expect_err("should fail");
        assert_eq!(err, 1);
    }

    // Empty input is never an error, whatever the access/JSON flags.
    #[test]
    fn select_empty_input_returns_empty() {
        let out = select_patches(&[], false, false).expect("ok");
        assert!(out.is_empty());
        let out = select_patches(&[], true, false).expect("ok");
        assert!(out.is_empty());
        let out = select_patches(&[], false, true).expect("ok");
        assert!(out.is_empty());
    }

    // Without paid access the paid candidate is ignored, leaving one free
    // candidate that is selected without prompting.
    #[test]
    fn select_free_user_paid_filtered_out_then_single_free_auto_selects() {
        let patches = vec![
            mk_patch("paid", "pkg:npm/foo@1.0", "paid", "2024-06-01"),
            mk_patch("free", "pkg:npm/foo@1.0", "free", "2024-01-01"),
        ];
        let out = select_patches(&patches, false, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "free");
        assert_eq!(out[0].tier, "free");
    }

    // ---- decide_patch_action ----------------------------------------------

    /// Build a manifest holding exactly one free-tier record for `purl`.
    fn manifest_with_entry(purl: &str, uuid: &str) -> PatchManifest {
        let mut m = PatchManifest::new();
        m.patches.insert(
            purl.to_string(),
            PatchRecord {
                uuid: uuid.to_string(),
                exported_at: String::new(),
                files: HashMap::new(),
                vulnerabilities: HashMap::new(),
                description: String::new(),
                license: String::new(),
                tier: "free".to_string(),
            },
        );
        m
    }

    #[test]
    fn decide_patch_action_added_when_purl_absent() {
        let manifest = PatchManifest::new();
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/foo@1.0", "uuid-a"),
            PatchAction::Added,
        );
    }

    #[test]
    fn decide_patch_action_skipped_when_same_uuid() {
        let manifest = manifest_with_entry("pkg:npm/foo@1.0", "uuid-a");
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/foo@1.0", "uuid-a"),
            PatchAction::Skipped,
        );
    }

    // A different UUID for an existing PURL reports the UUID being replaced.
    #[test]
    fn decide_patch_action_updated_when_different_uuid() {
        let manifest = manifest_with_entry("pkg:npm/foo@1.0", "uuid-a");
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/foo@1.0", "uuid-b"),
            PatchAction::Updated {
                old_uuid: "uuid-a".to_string()
            },
        );
    }

    // Lookup is keyed by PURL only: the same UUID under another PURL is Added.
    #[test]
    fn decide_patch_action_added_for_different_purl_even_with_overlapping_manifest() {
        let manifest = manifest_with_entry("pkg:npm/foo@1.0", "uuid-a");
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/bar@2.0", "uuid-a"),
            PatchAction::Added,
        );
    }

    // ---- severity ---------------------------------------------------------

    // "moderate" and "medium" are synonyms; empty and unknown labels rank
    // below "low".
    #[test]
    fn severity_rank_orders_canonical_labels() {
        assert!(severity_rank("critical") > severity_rank("high"));
        assert!(severity_rank("high") > severity_rank("medium"));
        assert!(severity_rank("medium") > severity_rank("low"));
        assert_eq!(severity_rank("moderate"), severity_rank("medium"));
        assert!(severity_rank("low") > severity_rank(""));
        assert!(severity_rank("low") > severity_rank("unknown"));
    }

    #[test]
    fn max_vuln_severity_picks_highest() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-low".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-low".into()],
                summary: String::new(),
                severity: "low".into(),
                description: String::new(),
            },
        );
        vulns.insert(
            "GHSA-crit".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-crit".into()],
                summary: String::new(),
                severity: "critical".into(),
                description: String::new(),
            },
        );
        vulns.insert(
            "GHSA-mod".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-mod".into()],
                summary: String::new(),
                severity: "moderate".into(),
                description: String::new(),
            },
        );
        assert_eq!(max_vuln_severity(&vulns).as_deref(), Some("critical"));
    }

    #[test]
    fn max_vuln_severity_returns_none_for_empty() {
        assert_eq!(max_vuln_severity(&HashMap::new()), None);
    }

    // ---- patch_event_metadata ---------------------------------------------

    // `published_at` is surfaced as `exportedAt`; `severity` is the highest
    // vulnerability severity; each vulnerability is flattened with its id.
    #[test]
    fn patch_event_metadata_includes_all_keys() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-aaaa-bbbb-cccc".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-2024-12345".into()],
                summary: "Prototype Pollution".into(),
                severity: "high".into(),
                description: "merge() does not check Object.prototype".into(),
            },
        );
        let patch = PatchResponse {
            uuid: "11111111-1111-4111-8111-111111111111".into(),
            purl: "pkg:npm/minimist@1.2.2".into(),
            published_at: "2024-01-01T00:00:00Z".into(),
            files: HashMap::new(),
            vulnerabilities: vulns,
            description: "Fixes prototype pollution in minimist".into(),
            license: "MIT".into(),
            tier: "free".into(),
        };
        let meta = patch_event_metadata(&patch);
        assert_eq!(meta["description"], "Fixes prototype pollution in minimist");
        assert_eq!(meta["license"], "MIT");
        assert_eq!(meta["tier"], "free");
        assert_eq!(meta["exportedAt"], "2024-01-01T00:00:00Z");
        assert_eq!(meta["severity"], "high");
        let vulns_out = meta["vulnerabilities"].as_array().unwrap();
        assert_eq!(vulns_out.len(), 1);
        assert_eq!(vulns_out[0]["id"], "GHSA-aaaa-bbbb-cccc");
        assert_eq!(vulns_out[0]["cves"][0], "CVE-2024-12345");
        assert_eq!(vulns_out[0]["severity"], "high");
        assert_eq!(vulns_out[0]["summary"], "Prototype Pollution");
    }

    // Insertion order is deliberately unsorted; output must be sorted by id
    // despite HashMap's unspecified iteration order.
    #[test]
    fn patch_event_metadata_sorts_vulnerabilities_by_id() {
        let mut vulns = HashMap::new();
        for id in ["GHSA-zzz", "GHSA-aaa", "GHSA-mmm"] {
            vulns.insert(
                id.into(),
                VulnerabilityResponse {
                    cves: Vec::new(),
                    summary: String::new(),
                    severity: "low".into(),
                    description: String::new(),
                },
            );
        }
        let patch = PatchResponse {
            uuid: String::new(),
            purl: String::new(),
            published_at: String::new(),
            files: HashMap::new(),
            vulnerabilities: vulns,
            description: String::new(),
            license: String::new(),
            tier: String::new(),
        };
        let meta = patch_event_metadata(&patch);
        let ids: Vec<&str> = meta["vulnerabilities"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v["id"].as_str().unwrap())
            .collect();
        assert_eq!(ids, ["GHSA-aaa", "GHSA-mmm", "GHSA-zzz"]);
    }

    // With no vulnerabilities the `severity` key is absent (not null), while
    // `vulnerabilities` is still present as an empty array.
    #[test]
    fn patch_event_metadata_omits_severity_when_no_vulns() {
        let patch = PatchResponse {
            uuid: String::new(),
            purl: String::new(),
            published_at: "ts".into(),
            files: HashMap::new(),
            vulnerabilities: HashMap::new(),
            description: "desc".into(),
            license: "MIT".into(),
            tier: "free".into(),
        };
        let meta = patch_event_metadata(&patch);
        assert!(meta.as_object().unwrap().get("severity").is_none());
        assert_eq!(meta["vulnerabilities"].as_array().unwrap().len(), 0);
    }

    // ---- truncate_with_ellipsis -------------------------------------------
    // The limit is in chars (not bytes) and includes the trailing "...".

    #[test]
    fn truncate_short_string_unchanged() {
        assert_eq!(truncate_with_ellipsis("hello", 60), "hello");
    }

    #[test]
    fn truncate_at_limit_unchanged() {
        let s = "a".repeat(60);
        assert_eq!(truncate_with_ellipsis(&s, 60), s);
    }

    #[test]
    fn truncate_long_ascii_adds_ellipsis_and_respects_limit() {
        let s = "a".repeat(100);
        let out = truncate_with_ellipsis(&s, 60);
        assert_eq!(out.chars().count(), 60);
        assert!(out.ends_with("..."));
        assert_eq!(out, format!("{}...", "a".repeat(57)));
    }

    // 30 chars but 90 bytes: under an 80-char limit even though the byte
    // length exceeds 80, so it must come back unchanged.
    #[test]
    fn truncate_multibyte_does_not_panic_and_is_char_safe() {
        let s = "日".repeat(30);
        let out = truncate_with_ellipsis(&s, 80);
        assert_eq!(out, s);
    }

    #[test]
    fn truncate_multibyte_long_truncates_on_char_boundary() {
        let s = "é".repeat(100);
        let out = truncate_with_ellipsis(&s, 80);
        assert_eq!(out.chars().count(), 80);
        assert!(out.ends_with("..."));
        assert_eq!(out, format!("{}...", "é".repeat(77)));
    }

    // ---- short_uuid -------------------------------------------------------
    // Keeps the first 8 chars; shorter inputs are returned whole.

    #[test]
    fn short_uuid_truncates_normal_uuid() {
        assert_eq!(short_uuid("80630680-4da6-45f9-bba8-b888e0ffd58c"), "80630680");
    }

    #[test]
    fn short_uuid_returns_whole_string_when_shorter_than_eight() {
        assert_eq!(short_uuid("abc"), "abc");
        assert_eq!(short_uuid(""), "");
    }

    // "abcdef€" is 7 chars but 9 bytes; byte offset 8 falls inside the 3-byte
    // '€', so a byte-based slice would panic.
    #[test]
    fn short_uuid_does_not_panic_on_multibyte_boundary() {
        let s = "ab€cd";
        assert_eq!(short_uuid(s), s);
        let s2 = "abcdef€";
        assert_eq!(short_uuid(s2), s2);
    }
}