1use clap::Args;
2use regex::Regex;
3use socket_patch_core::api::client::{
4 build_proxy_fallback_client, get_api_client_with_overrides, is_fallback_candidate,
5};
6use socket_patch_core::api::types::{
7 PatchResponse, PatchSearchResult, SearchResponse, VulnerabilityResponse,
8};
9use socket_patch_core::crawlers::{CrawlerOptions, Ecosystem};
10use socket_patch_core::manifest::operations::{read_manifest, write_manifest};
11use socket_patch_core::manifest::schema::{
12 PatchFileInfo, PatchManifest, PatchRecord, VulnerabilityInfo,
13};
14use socket_patch_core::patch::apply::select_installed_variants;
15use socket_patch_core::utils::fuzzy_match::fuzzy_match_packages;
16use socket_patch_core::utils::purl::{is_purl, strip_purl_qualifiers};
17use socket_patch_core::utils::telemetry::{track_patch_fetch_failed, track_patch_fetched};
18use std::collections::HashMap;
19use std::fmt;
20use std::path::{Path, PathBuf};
21
22use crate::args::{apply_env_toggles, GlobalArgs};
23use crate::ecosystem_dispatch::{crawl_all_ecosystems, find_packages_for_rollback, partition_purls};
24use crate::output::{confirm, select_one, SelectError};
25
/// Returns the PURL type segment (e.g. `"npm"` for `pkg:npm/lodash@4.17.21`).
///
/// Everything between the `pkg:` scheme and the first `/` is returned. Inputs without the
/// `pkg:` prefix yield an empty string.
fn ecosystem_from_purl(purl: &str) -> String {
    let Some(rest) = purl.strip_prefix("pkg:") else {
        return String::new();
    };
    let ecosystem = match rest.split_once('/') {
        Some((ty, _)) => ty,
        None => rest,
    };
    ecosystem.to_owned()
}
36
/// What recording a freshly fetched patch does to the manifest entry for its PURL
/// (computed by `decide_patch_action`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum PatchAction {
    /// The manifest has no entry for the PURL yet.
    Added,
    /// The manifest holds a different patch for the PURL, which will be replaced.
    Updated {
        /// UUID of the patch currently recorded in the manifest.
        old_uuid: String,
    },
    /// The manifest already records this exact patch UUID.
    Skipped,
}
50
51pub(crate) fn decide_patch_action(
54 manifest: &PatchManifest,
55 purl: &str,
56 new_uuid: &str,
57) -> PatchAction {
58 match manifest.patches.get(purl) {
59 Some(existing) if existing.uuid == new_uuid => PatchAction::Skipped,
60 Some(existing) => PatchAction::Updated {
61 old_uuid: existing.uuid.clone(),
62 },
63 None => PatchAction::Added,
64 }
65}
66
/// Maps a severity label to a sortable rank, case-insensitively:
/// critical=4, high=3, moderate/medium=2, low=1, anything else=0.
pub(crate) fn severity_rank(severity: &str) -> u8 {
    const RANKS: [(&str, u8); 5] = [
        ("critical", 4),
        ("high", 3),
        ("moderate", 2),
        ("medium", 2),
        ("low", 1),
    ];
    RANKS
        .iter()
        .find(|(label, _)| severity.eq_ignore_ascii_case(label))
        .map_or(0, |&(_, rank)| rank)
}
80
81pub(crate) fn max_vuln_severity(
85 vulns: &HashMap<String, VulnerabilityResponse>,
86) -> Option<String> {
87 vulns
88 .values()
89 .max_by_key(|v| severity_rank(&v.severity))
90 .map(|v| v.severity.clone())
91}
92
93pub(crate) fn patch_event_metadata(patch: &PatchResponse) -> serde_json::Value {
104 let mut vulns: Vec<serde_json::Value> = patch
105 .vulnerabilities
106 .iter()
107 .map(|(id, v)| {
108 serde_json::json!({
109 "id": id,
110 "cves": v.cves,
111 "severity": v.severity,
112 "summary": v.summary,
113 "description": v.description,
114 })
115 })
116 .collect();
117 vulns.sort_by(|a, b| {
120 a["id"]
121 .as_str()
122 .unwrap_or("")
123 .cmp(b["id"].as_str().unwrap_or(""))
124 });
125
126 let mut meta = serde_json::Map::new();
127 meta.insert(
128 "description".into(),
129 serde_json::Value::String(patch.description.clone()),
130 );
131 meta.insert(
132 "license".into(),
133 serde_json::Value::String(patch.license.clone()),
134 );
135 meta.insert(
136 "tier".into(),
137 serde_json::Value::String(patch.tier.clone()),
138 );
139 meta.insert(
140 "exportedAt".into(),
141 serde_json::Value::String(patch.published_at.clone()),
142 );
143 if let Some(sev) = max_vuln_severity(&patch.vulnerabilities) {
144 meta.insert("severity".into(), serde_json::Value::String(sev));
145 }
146 meta.insert("vulnerabilities".into(), serde_json::Value::Array(vulns));
147 serde_json::Value::Object(meta)
148}
149
150fn merge_metadata(record: &mut serde_json::Value, meta: serde_json::Value) {
154 if let (Some(record_obj), serde_json::Value::Object(meta_obj)) =
155 (record.as_object_mut(), meta)
156 {
157 for (k, v) in meta_obj {
158 record_obj.insert(k, v);
159 }
160 }
161}
162
163fn print_json(v: &serde_json::Value) {
165 println!("{}", serde_json::to_string_pretty(v).unwrap());
166}
167
/// Shortens `s` to at most `limit` characters (Unicode scalar values).
///
/// Strings that already fit are returned unchanged. Longer strings keep their first
/// `limit - 3` characters followed by `"..."`. When `limit` is too small to hold the
/// ellipsis (`limit < 3`), the first `limit` characters are returned without one, so the
/// result never exceeds `limit` characters.
pub(crate) fn truncate_with_ellipsis(s: &str, limit: usize) -> String {
    // `nth(limit)` is `None` iff `s` has at most `limit` chars; it stops scanning early
    // instead of counting the whole string.
    if s.chars().nth(limit).is_none() {
        return s.to_string();
    }
    const ELLIPSIS_LEN: usize = 3;
    if limit < ELLIPSIS_LEN {
        return s.chars().take(limit).collect();
    }
    let head: String = s.chars().take(limit - ELLIPSIS_LEN).collect();
    format!("{head}...")
}
182
/// Returns the first 8 bytes of `uuid` for compact display.
///
/// The whole string is returned unchanged when it is shorter than 8 bytes, or when byte 8
/// would split a multi-byte character.
fn short_uuid(uuid: &str) -> &str {
    const PREFIX_LEN: usize = 8;
    // `is_char_boundary` is false past the end and inside a multi-byte char.
    if uuid.is_char_boundary(PREFIX_LEN) {
        &uuid[..PREFIX_LEN]
    } else {
        uuid
    }
}
191
192fn empty_result_json(status: &str) -> serde_json::Value {
196 serde_json::json!({
197 "status": status,
198 "found": 0,
199 "downloaded": 0,
200 "applied": 0,
201 "patches": [],
202 })
203}
204
205async fn report_fetch_failure(
209 identifier: &str,
210 error: impl std::fmt::Display,
211 fallback_to_proxy: bool,
212 api_token: Option<&str>,
213 org_slug: Option<&str>,
214 json: bool,
215) -> i32 {
216 let msg = error.to_string();
217 track_patch_fetch_failed(identifier, &msg, fallback_to_proxy, api_token, org_slug).await;
218 report_error(json, msg);
219 1
220}
221
222fn report_error(json: bool, message: impl std::fmt::Display) {
225 let message = message.to_string();
226 if json {
227 print_json(&serde_json::json!({"status": "error", "error": message}));
228 } else {
229 eprintln!("Error: {message}");
230 }
231}
232
233async fn write_blob_entry(
236 blobs_dir: &Path,
237 b64: &str,
238 hash: &str,
239 file_path: &str,
240 label: &str,
241) -> Result<(), String> {
242 let decoded = base64_decode(b64)
243 .map_err(|e| format!("Failed to decode {label} for {file_path}: {e}"))?;
244 tokio::fs::write(blobs_dir.join(hash), &decoded)
245 .await
246 .map_err(|e| format!("Failed to write {label} for {file_path}: {e}"))
247}
248
249async fn write_all_patch_blobs(
253 blobs_dir: &Path,
254 patch: &PatchResponse,
255 quiet: bool,
256) -> Result<(), ()> {
257 for (file_path, file_info) in &patch.files {
258 if let (Some(blob), Some(hash)) =
259 (&file_info.blob_content, &file_info.after_hash)
260 {
261 if let Err(e) = write_blob_entry(blobs_dir, blob, hash, file_path, "blob").await {
262 if !quiet {
263 eprintln!(" [error] {e}");
264 }
265 return Err(());
266 }
267 }
268 if let (Some(blob), Some(hash)) =
269 (&file_info.before_blob_content, &file_info.before_hash)
270 {
271 if let Err(e) =
272 write_blob_entry(blobs_dir, blob, hash, file_path, "before-blob").await
273 {
274 if !quiet {
275 eprintln!(" [error] {e}");
276 }
277 return Err(());
278 }
279 }
280 }
281 Ok(())
282}
283
284fn vulnerabilities_for_manifest(
287 vulns: &HashMap<String, VulnerabilityResponse>,
288) -> HashMap<String, VulnerabilityInfo> {
289 vulns
290 .iter()
291 .map(|(id, v)| {
292 (
293 id.clone(),
294 VulnerabilityInfo {
295 cves: v.cves.clone(),
296 summary: v.summary.clone(),
297 severity: v.severity.clone(),
298 description: v.description.clone(),
299 },
300 )
301 })
302 .collect()
303}
304
305fn build_patch_record(
310 patch: &PatchResponse,
311 files: HashMap<String, PatchFileInfo>,
312) -> PatchRecord {
313 PatchRecord {
314 uuid: patch.uuid.clone(),
315 exported_at: patch.published_at.clone(),
316 files,
317 vulnerabilities: vulnerabilities_for_manifest(&patch.vulnerabilities),
318 description: patch.description.clone(),
319 license: patch.license.clone(),
320 tier: patch.tier.clone(),
321 }
322}
323
// Command-line arguments for the `get` subcommand.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into
// `--help` text, so adding those would change CLI output.
#[derive(Args)]
pub struct GetArgs {
    // What to look up. Its kind (UUID / CVE / GHSA / PURL / package name) is auto-detected
    // by `detect_identifier_type` unless one of the type flags below forces it.
    pub identifier: String,

    #[command(flatten)]
    pub common: GlobalArgs,

    // Force `identifier` to be treated as a patch UUID. `run` rejects more than one
    // of --id/--cve/--ghsa/--package.
    #[arg(long, default_value_t = false)]
    pub id: bool,

    // Force `identifier` to be treated as a CVE id.
    #[arg(long, default_value_t = false)]
    pub cve: bool,

    // Force `identifier` to be treated as a GHSA id.
    #[arg(long, default_value_t = false)]
    pub ghsa: bool,

    // Force a fuzzy package-name search over locally installed packages.
    #[arg(short = 'p', long = "package", default_value_t = false)]
    pub package: bool,

    // Record patches in the manifest but skip the apply step. `run` rejects it together
    // with `one_off`.
    #[arg(long = "save-only", alias = "no-apply", env = "SOCKET_SAVE_ONLY", default_value_t = false)]
    pub save_only: bool,

    // Mutually exclusive with `save_only` (checked in `run`).
    // NOTE(review): unlike `all_releases`, this flag and `save_only` do not set
    // `BoolishValueParser`; confirm which env-var values clap accepts for them.
    #[arg(long = "one-off", env = "SOCKET_ONE_OFF", default_value_t = false)]
    pub one_off: bool,

    // Keep every release variant of a package instead of narrowing to the installed one
    // (see `filter_to_installed_releases`). Env value parsed with clap's boolish parser.
    #[arg(
        long = "all-releases",
        env = "SOCKET_ALL_RELEASES",
        default_value_t = false,
        value_parser = clap::builder::BoolishValueParser::new(),
    )]
    pub all_releases: bool,
}
370
/// How the `get` identifier argument is interpreted.
#[derive(Clone, Copy, Debug, PartialEq)]
enum IdentifierType {
    /// A patch UUID, fetched directly.
    Uuid,
    /// A CVE id such as `CVE-2021-44228`.
    Cve,
    /// A GitHub advisory id such as `GHSA-xxxx-xxxx-xxxx`.
    Ghsa,
    /// A package URL (`pkg:...`).
    Purl,
    /// A free-form package name, fuzzy-matched against installed packages.
    Package,
}

impl fmt::Display for IdentifierType {
    /// Writes the user-facing label used in progress and "not found" messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Uuid => "UUID",
            Self::Cve => "CVE",
            Self::Ghsa => "GHSA",
            Self::Purl => "PURL",
            Self::Package => "package name",
        };
        f.write_str(label)
    }
}
391
392fn detect_identifier_type(identifier: &str) -> Option<IdentifierType> {
393 let uuid_re = Regex::new(r"(?i)^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$").unwrap();
394 let cve_re = Regex::new(r"(?i)^CVE-\d{4}-\d+$").unwrap();
395 let ghsa_re = Regex::new(r"(?i)^GHSA-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}$").unwrap();
396
397 if uuid_re.is_match(identifier) {
398 Some(IdentifierType::Uuid)
399 } else if cve_re.is_match(identifier) {
400 Some(IdentifierType::Cve)
401 } else if ghsa_re.is_match(identifier) {
402 Some(IdentifierType::Ghsa)
403 } else if is_purl(identifier) {
404 Some(IdentifierType::Purl)
405 } else {
406 None
407 }
408}
409
410pub fn select_patches(
419 patches: &[PatchSearchResult],
420 can_access_paid: bool,
421 is_json: bool,
422) -> Result<Vec<PatchSearchResult>, i32> {
423 let mut by_purl: HashMap<String, Vec<&PatchSearchResult>> = HashMap::new();
425 for p in patches {
426 if p.tier == "free" || can_access_paid {
427 by_purl.entry(p.purl.clone()).or_default().push(p);
428 }
429 }
430
431 let mut selected = Vec::new();
432
433 for (purl, mut group) in by_purl {
434 group.sort_by(|a, b| b.published_at.cmp(&a.published_at));
436
437 if can_access_paid {
438 let choice = group
440 .iter()
441 .find(|p| p.tier == "paid")
442 .or_else(|| group.first())
443 .unwrap();
444 selected.push((*choice).clone());
445 } else if group.len() == 1 {
446 selected.push(group[0].clone());
447 } else {
448 let options: Vec<String> = group
450 .iter()
451 .map(|p| {
452 let vuln_summary: Vec<String> = p
453 .vulnerabilities
454 .iter()
455 .map(|(id, v)| {
456 if v.cves.is_empty() {
457 id.clone()
458 } else {
459 v.cves.join(", ")
460 }
461 })
462 .collect();
463 let vulns = if vuln_summary.is_empty() {
464 String::new()
465 } else {
466 format!(" (fixes: {})", vuln_summary.join(", "))
467 };
468 let desc = truncate_with_ellipsis(&p.description, 60);
469 format!("{} [{}]{} - {}", p.uuid, p.tier, vulns, desc)
470 })
471 .collect();
472
473 match select_one(
474 &format!("Multiple patches available for {purl}. Select one:"),
475 &options,
476 is_json,
477 ) {
478 Ok(idx) => {
479 selected.push(group[idx].clone());
480 }
481 Err(SelectError::JsonModeNeedsExplicit) => {
482 let options_json: Vec<serde_json::Value> = group
483 .iter()
484 .map(|p| {
485 let vulns: Vec<serde_json::Value> = p
486 .vulnerabilities
487 .iter()
488 .map(|(id, v)| {
489 serde_json::json!({
490 "id": id,
491 "cves": v.cves,
492 "severity": v.severity,
493 "summary": v.summary,
494 })
495 })
496 .collect();
497 serde_json::json!({
498 "uuid": p.uuid,
499 "tier": p.tier,
500 "published_at": p.published_at,
501 "description": p.description,
502 "vulnerabilities": vulns,
503 })
504 })
505 .collect();
506 println!(
507 "{}",
508 serde_json::to_string_pretty(&serde_json::json!({
509 "status": "selection_required",
510 "error": format!("Multiple patches available for {purl}. Specify --id <UUID> to select one."),
511 "purl": purl,
512 "options": options_json,
513 }))
514 .unwrap()
515 );
516 return Err(1);
517 }
518 Err(SelectError::Cancelled) => {
519 eprintln!("Selection cancelled.");
520 return Err(0);
521 }
522 }
523 }
524 }
525
526 Ok(selected)
527}
528
/// Inputs for `download_and_apply_patches`; `run` builds one from `GetArgs`.
pub struct DownloadParams {
    /// Project root. Patches are stored under `<cwd>/.socket` (`blobs/` and
    /// `manifest.json`), and it is also the apply step's working directory.
    pub cwd: PathBuf,
    /// Org slug used for the API client only when `api_overrides.org_slug` is `None`.
    pub org: Option<String>,
    /// When true, patches are recorded in the manifest but the apply step is skipped.
    pub save_only: bool,
    /// NOTE(review): not read by `download_and_apply_patches` in the visible code;
    /// confirm where (or whether) this is honored.
    pub one_off: bool,
    /// Forwarded to the package crawler and to the apply step.
    pub global: bool,
    /// Forwarded to the package crawler and to the apply step alongside `global`.
    pub global_prefix: Option<PathBuf>,
    /// JSON mode: suppresses human-readable progress output. The caller prints the
    /// returned result document.
    pub json: bool,
    /// Suppresses human-readable progress output.
    pub silent: bool,
    /// Forwarded verbatim to the apply step's `download_mode`.
    pub download_mode: String,
    /// Base API client overrides (cloned, with `org_slug` filled from `org` if unset).
    pub api_overrides: socket_patch_core::api::client::ApiClientEnvOverrides,
    /// When true, keep every release variant instead of narrowing to the installed one
    /// via `filter_to_installed_releases`.
    pub all_releases: bool,
}
553
/// Narrows release-variant patches down to the variant(s) matching what is installed.
///
/// Search results whose ecosystem supports release variants
/// (`Ecosystem::supports_release_variants`) are grouped by their PURL with qualifiers
/// stripped. Results from other ecosystems, and groups with a single variant, are kept
/// unchanged.
///
/// For each group with several variants:
/// 1. The installed package is located with `find_packages_for_rollback`.
/// 2. Each variant's patch is fetched to obtain its file hashes.
/// 3. `select_installed_variants` decides which variants match the installed files.
///
/// If the package is not installed, or no variant matches, every variant is kept and a
/// human-readable warning is returned.
///
/// Returns `(kept_results, warnings)`. Non-variant results come first, in input order.
/// Variant groups follow in `HashMap` iteration order, which is not stable across runs.
///
/// NOTE(review): each variant's patch is fetched here and fetched again later by
/// `download_and_apply_patches`.
async fn filter_to_installed_releases(
    selected: &[PatchSearchResult],
    params: &DownloadParams,
    api_client: &socket_patch_core::api::client::ApiClient,
    org: Option<&str>,
) -> (Vec<PatchSearchResult>, Vec<String>) {
    // Variant-capable results grouped by qualifier-free base PURL; everything else is kept.
    let mut variant_groups: HashMap<String, Vec<PatchSearchResult>> = HashMap::new();
    let mut kept: Vec<PatchSearchResult> = Vec::new();
    for sr in selected {
        if Ecosystem::from_purl(&sr.purl).is_some_and(|e| e.supports_release_variants()) {
            variant_groups
                .entry(strip_purl_qualifiers(&sr.purl).to_string())
                .or_default()
                .push(sr.clone());
        } else {
            kept.push(sr.clone());
        }
    }

    let mut warnings: Vec<String> = Vec::new();

    // Only groups with 2+ variants need disambiguation.
    let mut multi: Vec<(String, Vec<PatchSearchResult>)> = Vec::new();
    for (base, variants) in variant_groups {
        if variants.len() <= 1 {
            kept.extend(variants);
        } else {
            multi.push((base, variants));
        }
    }

    // Nothing ambiguous: skip the crawl and network fetches entirely.
    if multi.is_empty() {
        return (kept, warnings);
    }

    // Locate installed copies of all ambiguous packages in a single crawl.
    let all_qualified: Vec<String> = multi
        .iter()
        .flat_map(|(_, variants)| variants.iter().map(|s| s.purl.clone()))
        .collect();
    let partitioned = partition_purls(&all_qualified, None);
    let crawler_options = CrawlerOptions {
        cwd: params.cwd.clone(),
        global: params.global,
        global_prefix: params.global_prefix.clone(),
        batch_size: 100,
    };
    let paths = find_packages_for_rollback(&partitioned, &crawler_options, true).await;

    for (base, variants) in multi {
        // The first variant PURL that resolves to an installed path identifies the install.
        let pkg_path = variants.iter().find_map(|s| paths.get(&s.purl)).cloned();
        let Some(pkg_path) = pkg_path else {
            warnings.push(format!(
                "{base} is not installed locally; keeping all {} release variant(s).",
                variants.len()
            ));
            kept.extend(variants);
            continue;
        };

        // Gather each variant's before/after file hashes. A fetch failure or missing
        // patch yields an empty map.
        // NOTE(review): whether an empty map can "match" depends on
        // `select_installed_variants`; confirm.
        let mut candidates: Vec<(String, HashMap<String, PatchFileInfo>)> = Vec::new();
        for s in &variants {
            match api_client.fetch_patch(org, &s.uuid).await {
                Ok(Some(patch)) => {
                    candidates.push((s.purl.clone(), files_for_selection(&patch)));
                }
                _ => candidates.push((s.purl.clone(), HashMap::new())),
            }
        }

        let refs: Vec<(&str, &HashMap<String, PatchFileInfo>)> = candidates
            .iter()
            .map(|(purl, files)| (purl.as_str(), files))
            .collect();

        // `matched` holds indices into `candidates` (and therefore into `refs`).
        let matched = select_installed_variants(&pkg_path, &refs).await;
        if matched.is_empty() {
            warnings.push(format!(
                "No release variant of {base} matches the installed distribution; keeping all {} variant(s).",
                variants.len()
            ));
            kept.extend(variants);
        } else {
            let winners: std::collections::HashSet<String> =
                matched.iter().map(|&i| candidates[i].0.clone()).collect();
            kept.extend(variants.into_iter().filter(|s| winners.contains(&s.purl)));
        }
    }

    (kept, warnings)
}
691
692fn files_for_selection(patch: &PatchResponse) -> HashMap<String, PatchFileInfo> {
697 let mut files = HashMap::new();
698 for (file_path, file_info) in &patch.files {
699 if let (Some(before), Some(after)) = (&file_info.before_hash, &file_info.after_hash) {
700 files.insert(
701 file_path.clone(),
702 PatchFileInfo {
703 before_hash: before.clone(),
704 after_hash: after.clone(),
705 },
706 );
707 }
708 }
709 files
710}
711
712pub async fn download_and_apply_patches(
716 selected: &[PatchSearchResult],
717 params: &DownloadParams,
718) -> (i32, serde_json::Value) {
719 let mut overrides = params.api_overrides.clone();
720 if overrides.org_slug.is_none() {
721 overrides.org_slug = params.org.clone();
722 }
723 let (api_client, _) =
724 socket_patch_core::api::client::get_api_client_with_overrides(overrides).await;
725 let effective_org: Option<&str> = None;
726
727 let socket_dir = params.cwd.join(".socket");
728 let blobs_dir = socket_dir.join("blobs");
729 let manifest_path = socket_dir.join("manifest.json");
730
731 if let Err(e) = tokio::fs::create_dir_all(&socket_dir).await {
732 let err = format!("Failed to create .socket directory: {}", e);
733 report_error(params.json, &err);
734 return (1, serde_json::json!({"status": "error", "error": err}));
735 }
736 if let Err(e) = tokio::fs::create_dir_all(&blobs_dir).await {
737 let err = format!("Failed to create blobs directory: {}", e);
738 report_error(params.json, &err);
739 return (1, serde_json::json!({"status": "error", "error": err}));
740 }
741
742 let mut manifest = match read_manifest(&manifest_path).await {
743 Ok(Some(m)) => m,
744 _ => PatchManifest::new(),
745 };
746
747 let mut narrow_warnings: Vec<String> = Vec::new();
751 let selected_owned: Vec<PatchSearchResult>;
752 let selected: &[PatchSearchResult] = if params.all_releases {
753 selected
754 } else {
755 let (kept, warns) =
756 filter_to_installed_releases(selected, params, &api_client, effective_org).await;
757 if !params.json && !params.silent {
758 for w in &warns {
759 eprintln!(" [note] {w}");
760 }
761 }
762 narrow_warnings = warns;
763 selected_owned = kept;
764 &selected_owned
765 };
766
767 if !params.json && !params.silent {
768 eprintln!("\nDownloading {} patch(es)...", selected.len());
769 }
770
771 let mut patches_added = 0;
772 let mut patches_skipped = 0;
773 let mut patches_failed = 0;
774 let mut downloaded_patches: Vec<serde_json::Value> = Vec::new();
775 let mut updates: Vec<String> = Vec::new();
776
777 for search_result in selected {
778 if let Some(existing) = manifest.patches.get(&search_result.purl) {
780 if existing.uuid != search_result.uuid {
781 updates.push(search_result.purl.clone());
782 if !params.json && !params.silent {
783 eprintln!(
784 " [update] {} (replacing {})",
785 search_result.purl,
786 short_uuid(&existing.uuid)
790 );
791 }
792 }
793 }
794
795 match api_client
796 .fetch_patch(effective_org, &search_result.uuid)
797 .await
798 {
799 Ok(Some(patch)) => {
800 let action = decide_patch_action(&manifest, &patch.purl, &patch.uuid);
804 if let PatchAction::Skipped = action {
805 if !params.json && !params.silent {
806 eprintln!(" [skip] {} (already in manifest)", patch.purl);
807 }
808 downloaded_patches.push(serde_json::json!({
809 "purl": patch.purl,
810 "uuid": patch.uuid,
811 "action": "skipped",
812 }));
813 patches_skipped += 1;
814 continue;
815 }
816
817 let mut files = HashMap::new();
821 for (file_path, file_info) in &patch.files {
822 if let (Some(before), Some(after)) =
823 (&file_info.before_hash, &file_info.after_hash)
824 {
825 files.insert(
826 file_path.clone(),
827 PatchFileInfo {
828 before_hash: before.clone(),
829 after_hash: after.clone(),
830 },
831 );
832 }
833 }
834
835 let quiet = params.json || params.silent;
836 if write_all_patch_blobs(&blobs_dir, &patch, quiet).await.is_err() {
837 patches_failed += 1;
838 downloaded_patches.push(serde_json::json!({
839 "purl": patch.purl,
840 "uuid": patch.uuid,
841 "action": "failed",
842 "error": "Blob decode or write failed",
843 }));
844 continue;
845 }
846
847 manifest
848 .patches
849 .insert(patch.purl.clone(), build_patch_record(&patch, files));
850
851 let mut action_record = match &action {
852 PatchAction::Updated { old_uuid } => {
853 if !params.json && !params.silent {
854 eprintln!(" [update] {}", patch.purl);
855 }
856 serde_json::json!({
857 "purl": patch.purl,
858 "uuid": patch.uuid,
859 "action": "updated",
860 "oldUuid": old_uuid,
861 })
862 }
863 _ => {
864 if !params.json && !params.silent {
865 eprintln!(" [add] {}", patch.purl);
866 }
867 serde_json::json!({
868 "purl": patch.purl,
869 "uuid": patch.uuid,
870 "action": "added",
871 })
872 }
873 };
874 merge_metadata(&mut action_record, patch_event_metadata(&patch));
879 downloaded_patches.push(action_record);
880 patches_added += 1;
881 }
882 Ok(None) => {
883 if !params.json && !params.silent {
884 eprintln!(" [fail] {} (could not fetch details)", search_result.purl);
885 }
886 downloaded_patches.push(serde_json::json!({
887 "purl": search_result.purl,
888 "uuid": search_result.uuid,
889 "action": "failed",
890 "error": "could not fetch details",
891 }));
892 patches_failed += 1;
893 }
894 Err(e) => {
895 if !params.json && !params.silent {
896 eprintln!(" [fail] {} ({e})", search_result.purl);
897 }
898 downloaded_patches.push(serde_json::json!({
899 "purl": search_result.purl,
900 "uuid": search_result.uuid,
901 "action": "failed",
902 "error": e.to_string(),
903 }));
904 patches_failed += 1;
905 }
906 }
907 }
908
909 if let Err(e) = write_manifest(&manifest_path, &manifest).await {
911 let msg = format!("Error writing manifest: {e}");
912 let err_json = serde_json::json!({ "status": "error", "error": &msg });
913 if params.json {
914 print_json(&err_json);
915 } else {
916 eprintln!("{msg}");
917 }
918 return (1, err_json);
919 }
920
921 if !params.json && !params.silent {
922 eprintln!("\nPatches saved to {}", manifest_path.display());
923 eprintln!(" Added: {patches_added}");
924 if patches_skipped > 0 {
925 eprintln!(" Skipped: {patches_skipped}");
926 }
927 if patches_failed > 0 {
928 eprintln!(" Failed: {patches_failed}");
929 }
930 if !updates.is_empty() {
931 eprintln!(" Updated: {}", updates.len());
932 }
933 }
934
935 let mut apply_succeeded = false;
937 if !params.save_only && patches_added > 0 {
938 if !params.json && !params.silent {
939 eprintln!("\nApplying patches...");
940 }
941 let apply_args = super::apply::ApplyArgs {
942 common: crate::args::GlobalArgs {
943 cwd: params.cwd.clone(),
944 manifest_path: manifest_path.display().to_string(),
945 global: params.global,
946 global_prefix: params.global_prefix.clone(),
947 silent: params.json || params.silent,
948 download_mode: params.download_mode.clone(),
949 ..crate::args::GlobalArgs::default()
950 },
951 force: false,
952 };
953 let code = super::apply::run(apply_args).await;
954 apply_succeeded = code == 0;
955 if code != 0 && !params.json && !params.silent {
956 eprintln!("\nSome patches could not be applied.");
957 }
958 }
959
960 let mut result_json = serde_json::json!({
961 "status": if patches_failed > 0 { "partial_failure" } else { "success" },
962 "found": selected.len(),
963 "downloaded": patches_added,
964 "skipped": patches_skipped,
965 "failed": patches_failed,
966 "applied": if apply_succeeded { patches_added } else { 0 },
967 "updated": updates.len(),
968 "patches": downloaded_patches,
969 });
970 if !narrow_warnings.is_empty() {
974 result_json["warnings"] = serde_json::json!(narrow_warnings);
975 }
976
977 let exit_code = if patches_failed > 0 || (!apply_succeeded && patches_added > 0 && !params.save_only) { 1 } else { 0 };
978 (exit_code, result_json)
979}
980
981pub async fn run(args: GetArgs) -> i32 {
982 let type_flags = [args.id, args.cve, args.ghsa, args.package]
984 .iter()
985 .filter(|&&f| f)
986 .count();
987 if type_flags > 1 {
988 report_error(
989 args.common.json,
990 "Only one of --id, --cve, --ghsa, or --package can be specified",
991 );
992 return 1;
993 }
994 if args.one_off && args.save_only {
995 if args.common.json {
996 print_json(&serde_json::json!({
997 "status": "error",
998 "error": "--one-off and --save-only cannot be used together",
999 }));
1000 } else {
1001 eprintln!("Error: --one-off and --save-only cannot be used together");
1002 }
1003 return 1;
1004 }
1005
1006 apply_env_toggles(&args.common);
1007 let overrides = args.common.api_client_overrides();
1008 let (mut api_client, mut use_public_proxy) =
1009 get_api_client_with_overrides(overrides.clone()).await;
1010 let telemetry_token = api_client.api_token().cloned();
1011 let telemetry_org = api_client.org_slug().cloned();
1012 let download_mode = args.common.download_mode.clone();
1013 let mut fallback_to_proxy = false;
1018
1019 let effective_org_slug: Option<&str> = None;
1021
1022 let id_type = if args.id {
1024 IdentifierType::Uuid
1025 } else if args.cve {
1026 IdentifierType::Cve
1027 } else if args.ghsa {
1028 IdentifierType::Ghsa
1029 } else if args.package {
1030 IdentifierType::Package
1031 } else {
1032 match detect_identifier_type(&args.identifier) {
1033 Some(t) => t,
1034 None => {
1035 if !args.common.json {
1036 println!("Treating \"{}\" as a package name search", args.identifier);
1037 }
1038 IdentifierType::Package
1039 }
1040 }
1041 };
1042
1043 if id_type == IdentifierType::Uuid {
1045 if !args.common.json {
1046 println!("Fetching patch by UUID: {}", args.identifier);
1047 }
1048 let mut fetch_result = api_client
1049 .fetch_patch(effective_org_slug, &args.identifier)
1050 .await;
1051 if !use_public_proxy {
1055 if let Err(ref e) = fetch_result {
1056 if is_fallback_candidate(e) {
1057 eprintln!(
1058 "Warning: authenticated API returned {e}; \
1059 falling back to public patch API proxy (free patches only)."
1060 );
1061 api_client = build_proxy_fallback_client(&overrides);
1062 use_public_proxy = true;
1063 fallback_to_proxy = true;
1064 fetch_result = api_client
1065 .fetch_patch(effective_org_slug, &args.identifier)
1066 .await;
1067 }
1068 }
1069 }
1070 match fetch_result {
1071 Ok(Some(patch)) => {
1072 if patch.tier == "paid" && use_public_proxy {
1073 track_patch_fetch_failed(
1074 &patch.uuid,
1075 "paid_required",
1076 fallback_to_proxy,
1077 telemetry_token.as_deref(),
1078 telemetry_org.as_deref(),
1079 )
1080 .await;
1081 if args.common.json {
1082 print_json(&serde_json::json!({
1083 "status": "paid_required",
1084 "found": 1,
1085 "downloaded": 0,
1086 "applied": 0,
1087 "patches": [{
1088 "purl": patch.purl,
1089 "uuid": patch.uuid,
1090 "tier": "paid",
1091 }],
1092 }));
1093 } else {
1094 println!("\nThis patch requires a paid subscription to download.");
1095 println!("\n Patch: {}", patch.purl);
1096 println!(" Tier: paid");
1097 println!("\n Upgrade at: https://socket.dev/pricing\n");
1098 }
1099 return 0;
1100 }
1101
1102 track_patch_fetched(
1108 &patch.uuid,
1109 &patch.tier,
1110 &ecosystem_from_purl(&patch.purl),
1111 &download_mode,
1112 fallback_to_proxy,
1113 telemetry_token.as_deref(),
1114 telemetry_org.as_deref(),
1115 )
1116 .await;
1117 return save_and_apply_patch(&args, &patch.purl, &patch.uuid, effective_org_slug)
1119 .await;
1120 }
1121 Ok(None) => {
1122 track_patch_fetch_failed(
1123 &args.identifier,
1124 "not_found",
1125 fallback_to_proxy,
1126 telemetry_token.as_deref(),
1127 telemetry_org.as_deref(),
1128 )
1129 .await;
1130 if args.common.json {
1131 print_json(&empty_result_json("not_found"));
1132 } else {
1133 println!("No patch found with UUID: {}", args.identifier);
1134 }
1135 return 0;
1136 }
1137 Err(e) => {
1138 return report_fetch_failure(
1139 &args.identifier,
1140 e,
1141 fallback_to_proxy,
1142 telemetry_token.as_deref(),
1143 telemetry_org.as_deref(),
1144 args.common.json,
1145 )
1146 .await;
1147 }
1148 }
1149 }
1150
1151 let search_response: SearchResponse = match id_type {
1155 IdentifierType::Cve | IdentifierType::Ghsa | IdentifierType::Purl => {
1156 if !args.common.json {
1157 let label = match id_type {
1158 IdentifierType::Cve => "CVE",
1159 IdentifierType::Ghsa => "GHSA",
1160 IdentifierType::Purl => "PURL",
1161 _ => unreachable!(),
1162 };
1163 println!("Searching patches for {label}: {}", args.identifier);
1164 }
1165 let result = match id_type {
1166 IdentifierType::Cve => {
1167 api_client
1168 .search_patches_by_cve(effective_org_slug, &args.identifier)
1169 .await
1170 }
1171 IdentifierType::Ghsa => {
1172 api_client
1173 .search_patches_by_ghsa(effective_org_slug, &args.identifier)
1174 .await
1175 }
1176 IdentifierType::Purl => {
1177 api_client
1178 .search_patches_by_package(effective_org_slug, &args.identifier)
1179 .await
1180 }
1181 _ => unreachable!(),
1182 };
1183 match result {
1184 Ok(r) => r,
1185 Err(e) => {
1186 return report_fetch_failure(
1187 &args.identifier,
1188 e,
1189 fallback_to_proxy,
1190 telemetry_token.as_deref(),
1191 telemetry_org.as_deref(),
1192 args.common.json,
1193 )
1194 .await;
1195 }
1196 }
1197 }
1198 IdentifierType::Package => {
1199 if !args.common.json {
1200 println!("Enumerating packages...");
1201 }
1202 let crawler_options = CrawlerOptions {
1203 cwd: args.common.cwd.clone(),
1204 global: args.common.global,
1205 global_prefix: args.common.global_prefix.clone(),
1206 batch_size: 100,
1207 };
1208 let (all_packages, _) = crawl_all_ecosystems(&crawler_options).await;
1209
1210 if all_packages.is_empty() {
1211 if args.common.json {
1212 print_json(&empty_result_json("no_packages"));
1213 } else if args.common.global {
1214 println!("No global packages found.");
1215 } else {
1216 #[allow(unused_mut)]
1217 let mut install_cmds = String::from("npm/yarn/pnpm/pip");
1218 #[cfg(feature = "cargo")]
1219 install_cmds.push_str("/cargo");
1220 #[cfg(feature = "golang")]
1221 install_cmds.push_str("/go");
1222 #[cfg(feature = "maven")]
1223 install_cmds.push_str("/mvn");
1224 #[cfg(feature = "composer")]
1225 install_cmds.push_str("/composer");
1226 println!("No packages found. Run {install_cmds} install first.");
1227 }
1228 return 0;
1229 }
1230
1231 if !args.common.json {
1232 println!("Found {} packages", all_packages.len());
1233 }
1234
1235 let matches = fuzzy_match_packages(&args.identifier, &all_packages, 20);
1236
1237 if matches.is_empty() {
1238 if args.common.json {
1239 print_json(&empty_result_json("no_match"));
1240 } else {
1241 println!("No packages matching \"{}\" found.", args.identifier);
1242 }
1243 return 0;
1244 }
1245
1246 if !args.common.json {
1247 println!(
1248 "Found {} matching package(s), checking for available patches...",
1249 matches.len()
1250 );
1251 }
1252
1253 let best_match = &matches[0];
1255 match api_client
1256 .search_patches_by_package(effective_org_slug, &best_match.purl)
1257 .await
1258 {
1259 Ok(r) => r,
1260 Err(e) => {
1261 return report_fetch_failure(
1262 &args.identifier,
1263 e,
1264 fallback_to_proxy,
1265 telemetry_token.as_deref(),
1266 telemetry_org.as_deref(),
1267 args.common.json,
1268 )
1269 .await;
1270 }
1271 }
1272 }
1273 _ => unreachable!(),
1274 };
1275
1276 if search_response.patches.is_empty() {
1277 if args.common.json {
1278 print_json(&empty_result_json("not_found"));
1279 } else {
1280 println!(
1281 "No patches found for {}: {}",
1282 id_type, args.identifier
1283 );
1284 }
1285 return 0;
1286 }
1287
1288 if !args.common.json {
1289 display_search_results(&search_response.patches, search_response.can_access_paid_patches);
1290 }
1291
1292 let accessible: Vec<_> = search_response
1294 .patches
1295 .iter()
1296 .filter(|p| p.tier == "free" || search_response.can_access_paid_patches)
1297 .cloned()
1298 .collect();
1299
1300 if accessible.is_empty() {
1301 if args.common.json {
1302 print_json(&serde_json::json!({
1303 "status": "paid_required",
1304 "found": search_response.patches.len(),
1305 "downloaded": 0,
1306 "applied": 0,
1307 "patches": search_response.patches.iter().map(|p| serde_json::json!({
1308 "purl": p.purl,
1309 "uuid": p.uuid,
1310 "tier": p.tier,
1311 })).collect::<Vec<_>>(),
1312 }));
1313 } else {
1314 println!("\nAll available patches require a paid subscription.");
1315 println!("\n Upgrade at: https://socket.dev/pricing\n");
1316 }
1317 return 0;
1318 }
1319
1320 let selected = match select_patches(
1322 &accessible,
1323 search_response.can_access_paid_patches,
1324 args.common.json,
1325 ) {
1326 Ok(s) => s,
1327 Err(code) => return code,
1328 };
1329
1330 if selected.is_empty() {
1331 if !args.common.json {
1332 println!("No patches selected.");
1333 }
1334 return 0;
1335 }
1336
1337 let prompt = format!("Download {} patch(es)?", selected.len());
1339 if !confirm(&prompt, true, args.common.yes, args.common.json) {
1340 if !args.common.json {
1341 println!("Download cancelled.");
1342 }
1343 return 0;
1344 }
1345
1346 let params = DownloadParams {
1348 cwd: args.common.cwd.clone(),
1349 org: args.common.org.clone(),
1350 save_only: args.save_only,
1351 one_off: args.one_off,
1352 global: args.common.global,
1353 global_prefix: args.common.global_prefix.clone(),
1354 json: args.common.json,
1355 silent: false,
1356 download_mode: args.common.download_mode.clone(),
1357 api_overrides: args.common.api_client_overrides(),
1358 all_releases: args.all_releases,
1359 };
1360
1361 let (code, result_json) = download_and_apply_patches(&selected, ¶ms).await;
1362
1363 if args.common.json {
1364 println!("{}", serde_json::to_string_pretty(&result_json).unwrap());
1365 }
1366
1367 code
1368}
1369
1370fn display_search_results(patches: &[PatchSearchResult], can_access_paid: bool) {
1371 println!("\nFound patches:\n");
1372
1373 for (i, patch) in patches.iter().enumerate() {
1374 let tier_label = if patch.tier == "paid" {
1375 " [PAID]"
1376 } else {
1377 " [FREE]"
1378 };
1379 let access_label = if patch.tier == "paid" && !can_access_paid {
1380 " (no access)"
1381 } else {
1382 ""
1383 };
1384
1385 println!(" {}. {}{}{}", i + 1, patch.purl, tier_label, access_label);
1386 println!(" UUID: {}", patch.uuid);
1387 if !patch.description.is_empty() {
1388 let desc = truncate_with_ellipsis(&patch.description, 80);
1389 println!(" Description: {desc}");
1390 }
1391
1392 let vuln_ids: Vec<_> = patch.vulnerabilities.keys().collect();
1393 if !vuln_ids.is_empty() {
1394 let vuln_summary: Vec<String> = patch
1395 .vulnerabilities
1396 .iter()
1397 .map(|(id, vuln)| {
1398 let cves = if vuln.cves.is_empty() {
1399 id.to_string()
1400 } else {
1401 vuln.cves.join(", ")
1402 };
1403 format!("{cves} ({})", vuln.severity)
1404 })
1405 .collect();
1406 println!(" Fixes: {}", vuln_summary.join(", "));
1407 }
1408 println!();
1409 }
1410}
1411
1412async fn save_and_apply_patch(
1413 args: &GetArgs,
1414 _purl: &str,
1415 uuid: &str,
1416 _org_slug: Option<&str>,
1417) -> i32 {
1418 let (api_client, _) =
1420 get_api_client_with_overrides(args.common.api_client_overrides()).await;
1421 let effective_org: Option<&str> = None; let patch = match api_client.fetch_patch(effective_org, uuid).await {
1424 Ok(Some(p)) => p,
1425 Ok(None) => {
1426 if args.common.json {
1427 print_json(&empty_result_json("not_found"));
1428 } else {
1429 println!("No patch found with UUID: {uuid}");
1430 }
1431 return 0;
1432 }
1433 Err(e) => {
1434 report_error(args.common.json, e);
1435 return 1;
1436 }
1437 };
1438
1439 let socket_dir = args.common.cwd.join(".socket");
1440 let blobs_dir = socket_dir.join("blobs");
1441 let manifest_path = socket_dir.join("manifest.json");
1442
1443 if let Err(e) = tokio::fs::create_dir_all(&blobs_dir).await {
1444 report_error(args.common.json, format!("Failed to create blobs directory: {e}"));
1445 return 1;
1446 }
1447
1448 let mut manifest = match read_manifest(&manifest_path).await {
1449 Ok(Some(m)) => m,
1450 _ => PatchManifest::new(),
1451 };
1452
1453 let mut files = HashMap::new();
1458 for (file_path, file_info) in &patch.files {
1459 if let Some(after) = &file_info.after_hash {
1460 files.insert(
1461 file_path.clone(),
1462 PatchFileInfo {
1463 before_hash: file_info.before_hash.clone().unwrap_or_default(),
1464 after_hash: after.clone(),
1465 },
1466 );
1467 }
1468 }
1469
1470 if write_all_patch_blobs(&blobs_dir, &patch, args.common.json)
1471 .await
1472 .is_err()
1473 {
1474 if args.common.json {
1475 print_json(&serde_json::json!({
1476 "status": "error",
1477 "found": 1,
1478 "downloaded": 0,
1479 "applied": 0,
1480 "error": "Blob decode or write failed",
1481 "patches": [{
1482 "purl": patch.purl,
1483 "uuid": patch.uuid,
1484 "action": "failed",
1485 "error": "Blob decode or write failed",
1486 }],
1487 }));
1488 } else {
1489 eprintln!("Error: Blob decode or write failed for patch {}", patch.purl);
1490 }
1491 return 1;
1492 }
1493
1494 let added = manifest
1495 .patches
1496 .get(&patch.purl)
1497 .is_none_or(|p| p.uuid != patch.uuid);
1498
1499 manifest
1500 .patches
1501 .insert(patch.purl.clone(), build_patch_record(&patch, files));
1502
1503 if let Err(e) = write_manifest(&manifest_path, &manifest).await {
1504 report_error(args.common.json, format!("Error writing manifest: {e}"));
1505 return 1;
1506 }
1507
1508 if !args.common.json {
1509 println!("\nPatch saved to {}", manifest_path.display());
1510 if added {
1511 println!(" Added: 1");
1512 } else {
1513 println!(" Skipped: 1 (already exists)");
1514 }
1515 }
1516
1517 let mut apply_succeeded = false;
1518 if !args.save_only && added {
1519 if !args.common.json {
1520 println!("\nApplying patches...");
1521 }
1522 let apply_args = super::apply::ApplyArgs {
1523 common: crate::args::GlobalArgs {
1524 cwd: args.common.cwd.clone(),
1525 manifest_path: manifest_path.display().to_string(),
1526 global: args.common.global,
1527 global_prefix: args.common.global_prefix.clone(),
1528 silent: args.common.json,
1529 download_mode: args.common.download_mode.clone(),
1530 ..crate::args::GlobalArgs::default()
1531 },
1532 force: false,
1533 };
1534 let code = super::apply::run(apply_args).await;
1535 apply_succeeded = code == 0;
1536 if code != 0 && !args.common.json {
1537 eprintln!("\nSome patches could not be applied.");
1538 }
1539 }
1540
1541 if args.common.json {
1542 let mut patch_record = serde_json::json!({
1543 "purl": patch.purl,
1544 "uuid": patch.uuid,
1545 "action": if added { "added" } else { "skipped" },
1546 });
1547 if added {
1548 merge_metadata(&mut patch_record, patch_event_metadata(&patch));
1551 }
1552 println!("{}", serde_json::to_string_pretty(&serde_json::json!({
1553 "status": "success",
1554 "found": 1,
1555 "downloaded": if added { 1 } else { 0 },
1556 "applied": if apply_succeeded { 1 } else { 0 },
1557 "patches": [patch_record],
1558 })).unwrap());
1559 }
1560
1561 if !apply_succeeded && added && !args.save_only { 1 } else { 0 }
1562}
1563
/// Decodes standard-alphabet base64 (RFC 4648 §4, `+` and `/`) into raw bytes.
///
/// Input handling:
/// - Trailing `=` padding is optional.
/// - `\n` and `\r` line breaks are ignored wherever they appear.
/// - Any leftover bits that do not complete a byte are discarded.
///
/// # Errors
/// Returns a human-readable message when the input contains:
/// - a character outside the base64 alphabet;
/// - alphabet data after a `=` padding character, as in concatenated padded
///   chunks;
/// - a dangling single character, which cannot encode a whole byte
///   (truncated input).
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
    const INVALID: u8 = u8::MAX;
    // ASCII byte -> 6-bit value lookup, built once at compile time.
    const DECODE_TABLE: [u8; 256] = {
        let alphabet = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut table = [INVALID; 256];
        let mut i = 0;
        while i < alphabet.len() {
            table[alphabet[i] as usize] = i as u8;
            i += 1;
        }
        table
    };

    let mut output = Vec::with_capacity(input.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u32;
    let mut sextets = 0usize;
    let mut padding_seen = false;

    // Iterate by `char` so a non-ASCII character is reported as itself rather
    // than as one of its UTF-8 bytes reinterpreted as Latin-1.
    for c in input.chars() {
        let val = match c {
            '\n' | '\r' => continue,
            '=' => {
                padding_seen = true;
                continue;
            }
            _ if c.is_ascii() => DECODE_TABLE[c as usize],
            _ => INVALID,
        };
        if val == INVALID {
            return Err(format!("Invalid base64 character: {c}"));
        }
        if padding_seen {
            return Err("Invalid base64: data after '=' padding".to_string());
        }
        buf = (buf << 6) | u32::from(val);
        bits += 6;
        sextets += 1;
        if bits >= 8 {
            bits -= 8;
            // At most 12 bits are buffered here, so after the shift the value
            // fits in a byte and the cast does not truncate data.
            output.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }

    // One leftover sextet carries only 6 bits — not enough for a byte — which
    // can only come from truncated or corrupted input.
    if sextets % 4 == 1 {
        return Err("Invalid base64: truncated input".to_string());
    }
    Ok(output)
}
1596
#[cfg(test)]
mod tests {
    use super::*;
    use socket_patch_core::api::types::VulnerabilityResponse;
    use std::collections::HashMap;

    // ---- detect_identifier_type -------------------------------------------

    #[test]
    fn detect_uuid_lowercase() {
        assert_eq!(
            detect_identifier_type("80630680-4da6-45f9-bba8-b888e0ffd58c"),
            Some(IdentifierType::Uuid)
        );
    }

    #[test]
    fn detect_uuid_uppercase() {
        assert_eq!(
            detect_identifier_type("80630680-4DA6-45F9-BBA8-B888E0FFD58C"),
            Some(IdentifierType::Uuid)
        );
    }

    #[test]
    fn detect_cve_uppercase() {
        assert_eq!(
            detect_identifier_type("CVE-2021-44906"),
            Some(IdentifierType::Cve)
        );
    }

    #[test]
    fn detect_cve_lowercase() {
        assert_eq!(
            detect_identifier_type("cve-2021-44906"),
            Some(IdentifierType::Cve)
        );
    }

    #[test]
    fn detect_ghsa_uppercase() {
        assert_eq!(
            detect_identifier_type("GHSA-abcd-1234-wxyz"),
            Some(IdentifierType::Ghsa)
        );
    }

    #[test]
    fn detect_ghsa_lowercase() {
        assert_eq!(
            detect_identifier_type("ghsa-abcd-1234-wxyz"),
            Some(IdentifierType::Ghsa)
        );
    }

    #[test]
    fn detect_purl() {
        assert_eq!(
            detect_identifier_type("pkg:npm/foo@1.0"),
            Some(IdentifierType::Purl)
        );
    }

    #[test]
    fn detect_package_name_returns_none() {
        assert_eq!(detect_identifier_type("minimist"), None);
    }

    #[test]
    fn detect_malformed_cve_returns_none() {
        assert_eq!(detect_identifier_type("CVE-not-a-year"), None);
    }

    #[test]
    fn detect_empty_string_returns_none() {
        assert_eq!(detect_identifier_type(""), None);
    }

    // ---- ecosystem_from_purl ----------------------------------------------

    #[test]
    fn ecosystem_from_purl_extracts_type_segment() {
        assert_eq!(ecosystem_from_purl("pkg:npm/foo@1.0"), "npm");
        assert_eq!(ecosystem_from_purl("pkg:pypi/requests@2.31.0"), "pypi");
    }

    #[test]
    fn ecosystem_from_purl_without_prefix_is_empty() {
        assert_eq!(ecosystem_from_purl("npm/foo@1.0"), "");
        assert_eq!(ecosystem_from_purl(""), "");
    }

    // ---- select_patches ---------------------------------------------------

    /// Builds a minimal search result; the description is derived from the
    /// UUID and there are no vulnerabilities.
    fn mk_patch(
        uuid: &str,
        purl: &str,
        tier: &str,
        published_at: &str,
    ) -> PatchSearchResult {
        PatchSearchResult {
            uuid: uuid.into(),
            purl: purl.into(),
            published_at: published_at.into(),
            description: format!("desc-{uuid}"),
            license: "MIT".into(),
            tier: tier.into(),
            vulnerabilities: HashMap::<String, VulnerabilityResponse>::new(),
        }
    }

    #[test]
    fn select_free_user_one_free_patch_returns_it() {
        let patches = vec![mk_patch("u1", "pkg:npm/foo@1.0", "free", "2024-01-01")];
        let out = select_patches(&patches, false, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "u1");
    }

    #[test]
    fn select_paid_user_prefers_paid_over_free_same_purl() {
        let patches = vec![
            mk_patch("free1", "pkg:npm/foo@1.0", "free", "2024-06-01"),
            mk_patch("paid1", "pkg:npm/foo@1.0", "paid", "2024-01-01"),
        ];
        let out = select_patches(&patches, true, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "paid1");
        assert_eq!(out[0].tier, "paid");
    }

    #[test]
    fn select_paid_user_picks_most_recent_paid() {
        let patches = vec![
            mk_patch("old", "pkg:npm/foo@1.0", "paid", "2024-01-01"),
            mk_patch("new", "pkg:npm/foo@1.0", "paid", "2024-06-01"),
        ];
        let out = select_patches(&patches, true, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "new");
    }

    #[test]
    fn select_paid_user_falls_back_to_most_recent_free_when_no_paid() {
        let patches = vec![
            mk_patch("old", "pkg:npm/foo@1.0", "free", "2024-01-01"),
            mk_patch("new", "pkg:npm/foo@1.0", "free", "2024-06-01"),
        ];
        let out = select_patches(&patches, true, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "new");
    }

    #[test]
    fn select_free_user_multi_free_json_mode_errors() {
        let patches = vec![
            mk_patch("a", "pkg:npm/foo@1.0", "free", "2024-01-01"),
            mk_patch("b", "pkg:npm/foo@1.0", "free", "2024-06-01"),
        ];
        let err = select_patches(&patches, false, true).expect_err("should fail");
        assert_eq!(err, 1);
    }

    #[test]
    fn select_empty_input_returns_empty() {
        let out = select_patches(&[], false, false).expect("ok");
        assert!(out.is_empty());
        let out = select_patches(&[], true, false).expect("ok");
        assert!(out.is_empty());
        let out = select_patches(&[], false, true).expect("ok");
        assert!(out.is_empty());
    }

    #[test]
    fn select_free_user_paid_filtered_out_then_single_free_auto_selects() {
        let patches = vec![
            mk_patch("paid", "pkg:npm/foo@1.0", "paid", "2024-06-01"),
            mk_patch("free", "pkg:npm/foo@1.0", "free", "2024-01-01"),
        ];
        let out = select_patches(&patches, false, false).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].uuid, "free");
        assert_eq!(out[0].tier, "free");
    }

    // ---- decide_patch_action ----------------------------------------------

    /// Builds a manifest holding a single `free` record for `purl` with `uuid`.
    fn manifest_with_entry(purl: &str, uuid: &str) -> PatchManifest {
        let mut m = PatchManifest::new();
        m.patches.insert(
            purl.to_string(),
            PatchRecord {
                uuid: uuid.to_string(),
                exported_at: String::new(),
                files: HashMap::new(),
                vulnerabilities: HashMap::new(),
                description: String::new(),
                license: String::new(),
                tier: "free".to_string(),
            },
        );
        m
    }

    #[test]
    fn decide_patch_action_added_when_purl_absent() {
        let manifest = PatchManifest::new();
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/foo@1.0", "uuid-a"),
            PatchAction::Added,
        );
    }

    #[test]
    fn decide_patch_action_skipped_when_same_uuid() {
        let manifest = manifest_with_entry("pkg:npm/foo@1.0", "uuid-a");
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/foo@1.0", "uuid-a"),
            PatchAction::Skipped,
        );
    }

    #[test]
    fn decide_patch_action_updated_when_different_uuid() {
        let manifest = manifest_with_entry("pkg:npm/foo@1.0", "uuid-a");
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/foo@1.0", "uuid-b"),
            PatchAction::Updated {
                old_uuid: "uuid-a".to_string()
            },
        );
    }

    #[test]
    fn decide_patch_action_added_for_different_purl_even_with_overlapping_manifest() {
        let manifest = manifest_with_entry("pkg:npm/foo@1.0", "uuid-a");
        assert_eq!(
            decide_patch_action(&manifest, "pkg:npm/bar@2.0", "uuid-a"),
            PatchAction::Added,
        );
    }

    // ---- severity ranking -------------------------------------------------

    #[test]
    fn severity_rank_orders_canonical_labels() {
        assert!(severity_rank("critical") > severity_rank("high"));
        assert!(severity_rank("high") > severity_rank("medium"));
        assert!(severity_rank("medium") > severity_rank("low"));
        assert_eq!(severity_rank("moderate"), severity_rank("medium"));
        assert!(severity_rank("low") > severity_rank(""));
        assert!(severity_rank("low") > severity_rank("unknown"));
    }

    #[test]
    fn severity_rank_is_case_insensitive() {
        assert_eq!(severity_rank("CRITICAL"), severity_rank("critical"));
        assert_eq!(severity_rank("High"), severity_rank("high"));
        assert_eq!(severity_rank("Moderate"), severity_rank("medium"));
    }

    #[test]
    fn max_vuln_severity_picks_highest() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-low".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-low".into()],
                summary: String::new(),
                severity: "low".into(),
                description: String::new(),
            },
        );
        vulns.insert(
            "GHSA-crit".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-crit".into()],
                summary: String::new(),
                severity: "critical".into(),
                description: String::new(),
            },
        );
        vulns.insert(
            "GHSA-mod".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-mod".into()],
                summary: String::new(),
                severity: "moderate".into(),
                description: String::new(),
            },
        );
        assert_eq!(max_vuln_severity(&vulns).as_deref(), Some("critical"));
    }

    #[test]
    fn max_vuln_severity_returns_none_for_empty() {
        assert_eq!(max_vuln_severity(&HashMap::new()), None);
    }

    // ---- patch_event_metadata ---------------------------------------------

    #[test]
    fn patch_event_metadata_includes_all_keys() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-aaaa-bbbb-cccc".into(),
            VulnerabilityResponse {
                cves: vec!["CVE-2024-12345".into()],
                summary: "Prototype Pollution".into(),
                severity: "high".into(),
                description: "merge() does not check Object.prototype".into(),
            },
        );
        let patch = PatchResponse {
            uuid: "11111111-1111-4111-8111-111111111111".into(),
            purl: "pkg:npm/minimist@1.2.2".into(),
            published_at: "2024-01-01T00:00:00Z".into(),
            files: HashMap::new(),
            vulnerabilities: vulns,
            description: "Fixes prototype pollution in minimist".into(),
            license: "MIT".into(),
            tier: "free".into(),
        };
        let meta = patch_event_metadata(&patch);
        assert_eq!(meta["description"], "Fixes prototype pollution in minimist");
        assert_eq!(meta["license"], "MIT");
        assert_eq!(meta["tier"], "free");
        assert_eq!(meta["exportedAt"], "2024-01-01T00:00:00Z");
        assert_eq!(meta["severity"], "high");
        let vulns_out = meta["vulnerabilities"].as_array().unwrap();
        assert_eq!(vulns_out.len(), 1);
        assert_eq!(vulns_out[0]["id"], "GHSA-aaaa-bbbb-cccc");
        assert_eq!(vulns_out[0]["cves"][0], "CVE-2024-12345");
        assert_eq!(vulns_out[0]["severity"], "high");
        assert_eq!(vulns_out[0]["summary"], "Prototype Pollution");
    }

    #[test]
    fn patch_event_metadata_sorts_vulnerabilities_by_id() {
        let mut vulns = HashMap::new();
        for id in ["GHSA-zzz", "GHSA-aaa", "GHSA-mmm"] {
            vulns.insert(
                id.into(),
                VulnerabilityResponse {
                    cves: Vec::new(),
                    summary: String::new(),
                    severity: "low".into(),
                    description: String::new(),
                },
            );
        }
        let patch = PatchResponse {
            uuid: String::new(),
            purl: String::new(),
            published_at: String::new(),
            files: HashMap::new(),
            vulnerabilities: vulns,
            description: String::new(),
            license: String::new(),
            tier: String::new(),
        };
        let meta = patch_event_metadata(&patch);
        let ids: Vec<&str> = meta["vulnerabilities"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v["id"].as_str().unwrap())
            .collect();
        assert_eq!(ids, ["GHSA-aaa", "GHSA-mmm", "GHSA-zzz"]);
    }

    #[test]
    fn patch_event_metadata_omits_severity_when_no_vulns() {
        let patch = PatchResponse {
            uuid: String::new(),
            purl: String::new(),
            published_at: "ts".into(),
            files: HashMap::new(),
            vulnerabilities: HashMap::new(),
            description: "desc".into(),
            license: "MIT".into(),
            tier: "free".into(),
        };
        let meta = patch_event_metadata(&patch);
        assert!(meta.as_object().unwrap().get("severity").is_none());
        assert_eq!(meta["vulnerabilities"].as_array().unwrap().len(), 0);
    }

    // ---- truncate_with_ellipsis -------------------------------------------

    #[test]
    fn truncate_short_string_unchanged() {
        assert_eq!(truncate_with_ellipsis("hello", 60), "hello");
    }

    #[test]
    fn truncate_at_limit_unchanged() {
        let s = "a".repeat(60);
        assert_eq!(truncate_with_ellipsis(&s, 60), s);
    }

    #[test]
    fn truncate_long_ascii_adds_ellipsis_and_respects_limit() {
        let s = "a".repeat(100);
        let out = truncate_with_ellipsis(&s, 60);
        assert_eq!(out.chars().count(), 60);
        assert!(out.ends_with("..."));
        assert_eq!(out, format!("{}...", "a".repeat(57)));
    }

    #[test]
    fn truncate_multibyte_does_not_panic_and_is_char_safe() {
        let s = "日".repeat(30);
        let out = truncate_with_ellipsis(&s, 80);
        assert_eq!(out, s);
    }

    #[test]
    fn truncate_multibyte_long_truncates_on_char_boundary() {
        let s = "é".repeat(100);
        let out = truncate_with_ellipsis(&s, 80);
        assert_eq!(out.chars().count(), 80);
        assert!(out.ends_with("..."));
        assert_eq!(out, format!("{}...", "é".repeat(77)));
    }

    // ---- short_uuid -------------------------------------------------------

    #[test]
    fn short_uuid_truncates_normal_uuid() {
        assert_eq!(short_uuid("80630680-4da6-45f9-bba8-b888e0ffd58c"), "80630680");
    }

    #[test]
    fn short_uuid_returns_whole_string_when_shorter_than_eight() {
        assert_eq!(short_uuid("abc"), "abc");
        assert_eq!(short_uuid(""), "");
    }

    #[test]
    fn short_uuid_does_not_panic_on_multibyte_boundary() {
        let s = "ab€cd";
        assert_eq!(short_uuid(s), s);
        let s2 = "abcdef€";
        assert_eq!(short_uuid(s2), s2);
    }

    // ---- base64_decode ----------------------------------------------------
    // These cases only cover well-formed input and characters outside the
    // alphabet, whose results are unambiguous for any conforming decoder.

    #[test]
    fn base64_decode_padded_and_unpadded() {
        assert_eq!(base64_decode("aGVsbG8=").unwrap(), b"hello".to_vec());
        assert_eq!(base64_decode("aGVsbG8").unwrap(), b"hello".to_vec());
        assert_eq!(base64_decode("QUJD").unwrap(), b"ABC".to_vec());
        assert_eq!(base64_decode("+/+/").unwrap(), vec![0xfb, 0xff, 0xbf]);
    }

    #[test]
    fn base64_decode_ignores_line_breaks() {
        assert_eq!(base64_decode("aGVs\r\nbG8=").unwrap(), b"hello".to_vec());
    }

    #[test]
    fn base64_decode_empty_input_is_empty() {
        assert!(base64_decode("").unwrap().is_empty());
    }

    #[test]
    fn base64_decode_rejects_characters_outside_alphabet() {
        assert!(base64_decode("aGV sbG8=").is_err());
        assert!(base64_decode("aGVs-bG8").is_err());
    }
}