1use std::collections::HashSet;
2
3use reqwest::header::{self, HeaderMap, HeaderValue};
4use reqwest::StatusCode;
5use serde::Serialize;
6
7use crate::api::types::*;
8use crate::constants::{
9 DEFAULT_PATCH_API_PROXY_URL, DEFAULT_SOCKET_API_URL, USER_AGENT as USER_AGENT_VALUE,
10};
11use crate::utils::env_compat::read_env_with_legacy;
12
13fn is_debug_enabled() -> bool {
16 match read_env_with_legacy("SOCKET_DEBUG", "SOCKET_PATCH_DEBUG") {
17 Some(val) => val == "1" || val == "true",
18 None => false,
19 }
20}
21
22fn debug_log(message: &str) {
24 if is_debug_enabled() {
25 eprintln!("[socket-patch debug] {}", message);
26 }
27}
28
/// Maps a severity label to a sort rank where a lower number is more severe.
///
/// Matching is case-insensitive. `"moderate"` (the label GitHub advisory data
/// uses for the middle tier) ranks the same as `"medium"`. Unrecognised labels
/// and `None` both rank last (4), so any known severity outranks them.
fn get_severity_order(severity: Option<&str>) -> u8 {
    match severity.map(str::to_lowercase).as_deref() {
        Some("critical") => 0,
        Some("high") => 1,
        Some("medium" | "moderate") => 2,
        Some("low") => 3,
        _ => 4,
    }
}
39
/// Construction parameters for [`ApiClient::new`].
///
/// `Debug` is implemented by hand so that formatting these options (e.g. in a
/// log line or panic message) never prints the API token.
#[derive(Clone)]
pub struct ApiClientOptions {
    /// Base URL every request path is appended to; `ApiClient::new` strips
    /// trailing `/` characters.
    pub api_url: String,
    /// Token sent as `Authorization: Bearer <token>` when present.
    pub api_token: Option<String>,
    /// When true, requests use the public proxy routes (`/patch/...`) instead of
    /// the org-scoped `/v0/orgs/<slug>/...` routes.
    pub use_public_proxy: bool,
    /// Organization slug used for org-scoped routes when a call does not supply one.
    pub org_slug: Option<String>,
}

impl std::fmt::Debug for ApiClientOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiClientOptions")
            .field("api_url", &self.api_url)
            // Report only whether a token is configured, never its value.
            .field("api_token", &self.api_token.as_ref().map(|_| "<redacted>"))
            .field("use_public_proxy", &self.use_public_proxy)
            .field("org_slug", &self.org_slug)
            .finish()
    }
}
52
53#[derive(Debug, Clone)]
59pub struct ApiClient {
60 client: reqwest::Client,
61 api_url: String,
62 api_token: Option<String>,
63 use_public_proxy: bool,
64 org_slug: Option<String>,
65}
66
/// JSON request body for the org-scoped batch search route
/// (`POST /v0/orgs/<slug>/patches/batch`, built in `search_patches_batch`).
#[derive(Serialize)]
struct BatchSearchBody {
    /// One entry per package URL to look up.
    components: Vec<BatchComponent>,
}
72
/// A single package entry inside [`BatchSearchBody`].
#[derive(Serialize)]
struct BatchComponent {
    /// Package URL (purl) to search patches for, passed through as given.
    purl: String,
}
77
78impl ApiClient {
79 pub fn new(options: ApiClientOptions) -> Self {
84 let api_url = options.api_url.trim_end_matches('/').to_string();
85
86 let mut default_headers = HeaderMap::new();
87 default_headers.insert(
88 header::USER_AGENT,
89 HeaderValue::from_static(USER_AGENT_VALUE),
90 );
91 default_headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
92
93 if let Some(ref token) = options.api_token {
94 if let Ok(hv) = HeaderValue::from_str(&format!("Bearer {}", token)) {
95 default_headers.insert(header::AUTHORIZATION, hv);
96 }
97 }
98
99 let client = reqwest::Client::builder()
100 .default_headers(default_headers)
101 .build()
102 .expect("failed to build reqwest client");
103
104 Self {
105 client,
106 api_url,
107 api_token: options.api_token,
108 use_public_proxy: options.use_public_proxy,
109 org_slug: options.org_slug,
110 }
111 }
112
113 pub fn api_token(&self) -> Option<&String> {
115 self.api_token.as_ref()
116 }
117
118 pub fn org_slug(&self) -> Option<&String> {
120 self.org_slug.as_ref()
121 }
122
123 async fn get_json<T: serde::de::DeserializeOwned>(
127 &self,
128 path: &str,
129 ) -> Result<Option<T>, ApiError> {
130 let url = format!("{}{}", self.api_url, path);
131 debug_log(&format!("GET {}", url));
132
133 let resp = self
134 .client
135 .get(&url)
136 .send()
137 .await
138 .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
139
140 Self::handle_json_response(resp, self.use_public_proxy).await
141 }
142
143 async fn post_json<T: serde::de::DeserializeOwned, B: Serialize>(
145 &self,
146 path: &str,
147 body: &B,
148 ) -> Result<Option<T>, ApiError> {
149 let url = format!("{}{}", self.api_url, path);
150 debug_log(&format!("POST {}", url));
151
152 let resp = self
153 .client
154 .post(&url)
155 .header(header::CONTENT_TYPE, "application/json")
156 .json(body)
157 .send()
158 .await
159 .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
160
161 Self::handle_json_response(resp, self.use_public_proxy).await
162 }
163
164 async fn handle_json_response<T: serde::de::DeserializeOwned>(
166 resp: reqwest::Response,
167 use_public_proxy: bool,
168 ) -> Result<Option<T>, ApiError> {
169 let status = resp.status();
170
171 match status {
172 StatusCode::OK => {
173 let body = resp
174 .json::<T>()
175 .await
176 .map_err(|e| ApiError::Parse(format!("Failed to parse response: {}", e)))?;
177 Ok(Some(body))
178 }
179 StatusCode::NOT_FOUND => Ok(None),
180 StatusCode::UNAUTHORIZED => Err(ApiError::Unauthorized(
181 "Unauthorized: Invalid API token".into(),
182 )),
183 StatusCode::FORBIDDEN => {
184 let msg = if use_public_proxy {
185 "Forbidden: This patch is only available to paid subscribers. \
186 Sign up at https://socket.dev to access paid patches."
187 } else {
188 "Forbidden: Access denied. This may be a paid patch or \
189 you may not have access to this organization."
190 };
191 Err(ApiError::Forbidden(msg.into()))
192 }
193 StatusCode::TOO_MANY_REQUESTS => Err(ApiError::RateLimited(
194 "Rate limit exceeded. Please try again later.".into(),
195 )),
196 _ => {
197 let text = resp.text().await.unwrap_or_default();
198 Err(ApiError::Other(format!(
199 "API request failed with status {}: {}",
200 status.as_u16(),
201 text
202 )))
203 }
204 }
205 }
206
207 pub async fn fetch_patch(
213 &self,
214 org_slug: Option<&str>,
215 uuid: &str,
216 ) -> Result<Option<PatchResponse>, ApiError> {
217 let path = if self.use_public_proxy {
218 format!("/patch/view/{}", uuid)
219 } else {
220 let slug = org_slug.or(self.org_slug.as_deref()).unwrap_or("default");
221 format!("/v0/orgs/{}/patches/view/{}", slug, uuid)
222 };
223 self.get_json(&path).await
224 }
225
226 async fn search_patches_by_route(
230 &self,
231 org_slug: Option<&str>,
232 route: &str,
233 identifier: &str,
234 ) -> Result<SearchResponse, ApiError> {
235 let encoded = urlencoding_encode(identifier);
236 let path = if self.use_public_proxy {
237 format!("/patch/{route}/{encoded}")
238 } else {
239 let slug = org_slug.or(self.org_slug.as_deref()).unwrap_or("default");
240 format!("/v0/orgs/{slug}/patches/{route}/{encoded}")
241 };
242 let result = self.get_json::<SearchResponse>(&path).await?;
243 Ok(result.unwrap_or_else(|| SearchResponse {
244 patches: Vec::new(),
245 can_access_paid_patches: false,
246 }))
247 }
248
249 pub async fn search_patches_by_cve(
251 &self,
252 org_slug: Option<&str>,
253 cve_id: &str,
254 ) -> Result<SearchResponse, ApiError> {
255 self.search_patches_by_route(org_slug, "by-cve", cve_id)
256 .await
257 }
258
259 pub async fn search_patches_by_ghsa(
261 &self,
262 org_slug: Option<&str>,
263 ghsa_id: &str,
264 ) -> Result<SearchResponse, ApiError> {
265 self.search_patches_by_route(org_slug, "by-ghsa", ghsa_id)
266 .await
267 }
268
269 pub async fn search_patches_by_package(
274 &self,
275 org_slug: Option<&str>,
276 purl: &str,
277 ) -> Result<SearchResponse, ApiError> {
278 self.search_patches_by_route(org_slug, "by-package", purl)
279 .await
280 }
281
282 pub async fn search_patches_batch(
291 &self,
292 org_slug: Option<&str>,
293 purls: &[String],
294 ) -> Result<BatchSearchResponse, ApiError> {
295 if !self.use_public_proxy {
296 let slug = org_slug.or(self.org_slug.as_deref()).unwrap_or("default");
297 let path = format!("/v0/orgs/{}/patches/batch", slug);
298 let body = BatchSearchBody {
299 components: purls
300 .iter()
301 .map(|p| BatchComponent { purl: p.clone() })
302 .collect(),
303 };
304 let result = self
305 .post_json::<BatchSearchResponse, _>(&path, &body)
306 .await?;
307 return Ok(result.unwrap_or_else(|| BatchSearchResponse {
308 packages: Vec::new(),
309 can_access_paid_patches: false,
310 }));
311 }
312
313 self.search_patches_batch_via_individual_queries(purls)
315 .await
316 }
317
318 async fn search_patches_batch_via_individual_queries(
324 &self,
325 purls: &[String],
326 ) -> Result<BatchSearchResponse, ApiError> {
327 const CONCURRENCY_LIMIT: usize = 10;
328
329 let mut all_results: Vec<(String, Option<SearchResponse>)> = Vec::new();
331
332 for chunk in purls.chunks(CONCURRENCY_LIMIT) {
333 let mut join_set = tokio::task::JoinSet::new();
335
336 for purl in chunk {
337 let purl = purl.clone();
338 let client = self.clone();
339 join_set.spawn(async move {
340 let resp = client.search_patches_by_package(None, &purl).await;
341 match resp {
342 Ok(r) => (purl, Some(r)),
343 Err(e) => {
344 debug_log(&format!("Error fetching patches for {}: {}", purl, e));
345 (purl, None)
346 }
347 }
348 });
349 }
350
351 while let Some(result) = join_set.join_next().await {
352 match result {
353 Ok(pair) => all_results.push(pair),
354 Err(e) => {
355 debug_log(&format!("Task join error: {}", e));
356 }
357 }
358 }
359 }
360
361 Ok(assemble_batch_from_individual(all_results))
363 }
364
365 pub async fn fetch_organizations(
367 &self,
368 ) -> Result<Vec<crate::api::types::OrganizationInfo>, ApiError> {
369 let path = "/v0/organizations";
370 match self
371 .get_json::<crate::api::types::OrganizationsResponse>(path)
372 .await?
373 {
374 Some(resp) => Ok(resp.organizations.into_values().collect()),
375 None => Ok(Vec::new()),
376 }
377 }
378
379 pub async fn resolve_org_slug(&self) -> Result<String, ApiError> {
385 let orgs = self.fetch_organizations().await?;
386 select_org_slug(orgs)
387 }
388
389 pub async fn fetch_blob(&self, hash: &str) -> Result<Option<Vec<u8>>, ApiError> {
395 if !is_valid_sha256_hex(hash) {
397 return Err(ApiError::InvalidHash(format!(
398 "Invalid hash format: {}. Expected SHA256 hash (64 hex characters).",
399 hash
400 )));
401 }
402 self.fetch_binary("blob", "blob", hash).await
403 }
404
405 pub async fn fetch_diff(&self, uuid: &str) -> Result<Option<Vec<u8>>, ApiError> {
411 if !is_valid_uuid(uuid) {
412 return Err(ApiError::InvalidHash(format!(
413 "Invalid patch UUID: {}",
414 uuid
415 )));
416 }
417 self.fetch_binary("diff", "diff", uuid).await
418 }
419
420 pub async fn fetch_package(&self, uuid: &str) -> Result<Option<Vec<u8>>, ApiError> {
424 if !is_valid_uuid(uuid) {
425 return Err(ApiError::InvalidHash(format!(
426 "Invalid patch UUID: {}",
427 uuid
428 )));
429 }
430 self.fetch_binary("package", "package", uuid).await
431 }
432
433 fn binary_url(&self, kind: &str, identifier: &str) -> (String, bool) {
447 if self.api_token.is_some() && self.org_slug.is_some() && !self.use_public_proxy {
448 let slug = self.org_slug.as_deref().unwrap();
449 let u = format!(
450 "{}/v0/orgs/{}/patches/{}/{}",
451 self.api_url, slug, kind, identifier
452 );
453 (u, true)
454 } else {
455 let base = if self.use_public_proxy {
456 self.api_url.clone()
457 } else {
458 read_env_with_legacy("SOCKET_PROXY_URL", "SOCKET_PATCH_PROXY_URL")
459 .unwrap_or_else(|| DEFAULT_PATCH_API_PROXY_URL.to_string())
460 };
461 let u = format!(
462 "{}/patch/{}/{}",
463 base.trim_end_matches('/'),
464 kind,
465 identifier
466 );
467 (u, false)
468 }
469 }
470
471 async fn fetch_binary(
477 &self,
478 kind: &str,
479 label: &str,
480 identifier: &str,
481 ) -> Result<Option<Vec<u8>>, ApiError> {
482 let (url, use_auth) = self.binary_url(kind, identifier);
483
484 debug_log(&format!("GET {} {}", label, url));
485
486 let resp = if use_auth {
490 self.client
491 .get(&url)
492 .header(header::ACCEPT, "application/octet-stream")
493 .send()
494 .await
495 } else {
496 let mut headers = HeaderMap::new();
497 headers.insert(
498 header::USER_AGENT,
499 HeaderValue::from_static(USER_AGENT_VALUE),
500 );
501 headers.insert(
502 header::ACCEPT,
503 HeaderValue::from_static("application/octet-stream"),
504 );
505
506 let plain_client = reqwest::Client::builder()
507 .default_headers(headers)
508 .build()
509 .expect("failed to build plain reqwest client");
510
511 plain_client.get(&url).send().await
512 };
513
514 let resp = resp.map_err(|e| {
515 ApiError::Network(format!(
516 "Network error fetching {} {}: {}",
517 label, identifier, e
518 ))
519 })?;
520
521 let status = resp.status();
522
523 match status {
524 StatusCode::OK => {
525 let bytes = resp.bytes().await.map_err(|e| {
526 ApiError::Network(format!(
527 "Error reading {} body for {}: {}",
528 label, identifier, e
529 ))
530 })?;
531 Ok(Some(bytes.to_vec()))
532 }
533 StatusCode::NOT_FOUND => Ok(None),
534 _ => {
535 let text = resp.text().await.unwrap_or_default();
536 Err(ApiError::Other(format!(
537 "Failed to fetch {} {}: status {} - {}",
538 label,
539 identifier,
540 status.as_u16(),
541 text,
542 )))
543 }
544 }
545 }
546}
547
/// Explicit settings that take precedence over environment variables when
/// building a client with `get_api_client_with_overrides`.
///
/// Any field left `None` falls back to its environment variable (and then to
/// the built-in default where one exists).
#[derive(Debug, Clone, Default)]
pub struct ApiClientEnvOverrides {
    /// Overrides `SOCKET_API_URL`.
    pub api_url: Option<String>,
    /// Overrides `SOCKET_API_TOKEN`. NOTE(review): this struct derives `Debug`,
    /// so formatting it with `{:?}` prints the token.
    pub api_token: Option<String>,
    /// Overrides `SOCKET_ORG_SLUG`.
    pub org_slug: Option<String>,
    /// Overrides `SOCKET_PROXY_URL` (legacy `SOCKET_PATCH_PROXY_URL`) for the
    /// public proxy client.
    pub proxy_url: Option<String>,
}
562
563pub async fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) {
584 get_api_client_with_overrides(ApiClientEnvOverrides {
585 org_slug: org_slug.map(String::from),
586 ..ApiClientEnvOverrides::default()
587 })
588 .await
589}
590
591pub async fn get_api_client_with_overrides(overrides: ApiClientEnvOverrides) -> (ApiClient, bool) {
597 let api_token = overrides
598 .api_token
599 .or_else(|| std::env::var("SOCKET_API_TOKEN").ok())
600 .filter(|t| !t.is_empty());
601 let resolved_org_slug = overrides
602 .org_slug
603 .or_else(|| std::env::var("SOCKET_ORG_SLUG").ok());
604
605 if api_token.is_none() {
606 let proxy_url = overrides.proxy_url.unwrap_or_else(|| {
607 read_env_with_legacy("SOCKET_PROXY_URL", "SOCKET_PATCH_PROXY_URL")
608 .unwrap_or_else(|| DEFAULT_PATCH_API_PROXY_URL.to_string())
609 });
610 eprintln!("No SOCKET_API_TOKEN set. Using public patch API proxy (free patches only).");
611 let client = ApiClient::new(ApiClientOptions {
612 api_url: proxy_url,
613 api_token: None,
614 use_public_proxy: true,
615 org_slug: None,
616 });
617 return (client, true);
618 }
619
620 if let Some(ref t) = api_token {
623 if let Some(msg) = validate_token_shape(t) {
624 eprintln!("{msg}");
625 }
626 }
627
628 let api_url = overrides
629 .api_url
630 .or_else(|| std::env::var("SOCKET_API_URL").ok())
631 .unwrap_or_else(|| DEFAULT_SOCKET_API_URL.to_string());
632
633 let final_org_slug = if resolved_org_slug.is_some() {
635 resolved_org_slug
636 } else {
637 let temp_client = ApiClient::new(ApiClientOptions {
638 api_url: api_url.clone(),
639 api_token: api_token.clone(),
640 use_public_proxy: false,
641 org_slug: None,
642 });
643 match temp_client.resolve_org_slug().await {
644 Ok(slug) => Some(slug),
645 Err(e) => {
646 eprintln!("Warning: Could not auto-detect organization: {e}");
647 if matches!(e, ApiError::Unauthorized(_)) {
648 if let Some(ref t) = api_token {
649 if looks_like_token_hash(t) {
650 eprintln!(
651 " Hint: SOCKET_API_TOKEN starts with `{}-` \
652 which is the stored hash format. Set it to \
653 the raw `sktsec_..._api` value instead.",
654 t.split('-').next().unwrap_or("sha512")
655 );
656 }
657 }
658 }
659 None
660 }
661 }
662 };
663
664 let client = ApiClient::new(ApiClientOptions {
665 api_url,
666 api_token,
667 use_public_proxy: false,
668 org_slug: final_org_slug,
669 });
670 (client, false)
671}
672
673pub fn build_proxy_fallback_client(overrides: &ApiClientEnvOverrides) -> ApiClient {
681 let proxy_url = overrides.proxy_url.clone().unwrap_or_else(|| {
682 read_env_with_legacy("SOCKET_PROXY_URL", "SOCKET_PATCH_PROXY_URL")
683 .unwrap_or_else(|| DEFAULT_PATCH_API_PROXY_URL.to_string())
684 });
685 ApiClient::new(ApiClientOptions {
686 api_url: proxy_url,
687 api_token: None,
688 use_public_proxy: true,
689 org_slug: None,
690 })
691}
692
/// Returns true when `token` starts with an SRI-style digest prefix
/// (`sha256-`, `sha384-` or `sha512-`, case-sensitive). Such values are the
/// stored hash of a token, not the token itself.
pub fn looks_like_token_hash(token: &str) -> bool {
    const SRI_ALGORITHMS: [&str; 3] = ["sha256", "sha384", "sha512"];
    token
        .split_once('-')
        .map_or(false, |(algorithm, _)| SRI_ALGORITHMS.contains(&algorithm))
}
705
706pub fn validate_token_shape(token: &str) -> Option<String> {
722 let has_prefix = token.starts_with("sktsec_");
723 let has_suffix = token.ends_with("_api") || token.ends_with("_agent");
724 let plausible_len = token.len() >= 55;
725 if has_prefix && has_suffix && plausible_len {
726 return None;
727 }
728 let len = token.len();
729 let head: String = token.chars().take(8).collect();
730 let tail_start = len.saturating_sub(4);
731 let tail: String = token.chars().skip(tail_start).collect();
732 let preview = if len <= 12 {
733 token.to_string()
734 } else {
735 format!("{head}...{tail}")
736 };
737 let hash_hint = if looks_like_token_hash(token) {
738 "\n That value looks like an SRI-format hash (sha###-<base64>) — \
739 the server stores the *hash* of your token, not what you should \
740 set here. Use the raw `sktsec_..._api` value shown when the token \
741 was generated."
742 } else {
743 ""
744 };
745 Some(format!(
746 "Warning: SOCKET_API_TOKEN does not look like a Socket API token \
747 (expected `sktsec_<44 chars>_api`).{hash_hint}\n \
748 Got: {preview} ({len} chars). Continuing anyway; the server may \
749 reject this with 401."
750 ))
751}
752
753pub fn is_fallback_candidate(err: &ApiError) -> bool {
758 matches!(err, ApiError::Unauthorized(_) | ApiError::Forbidden(_))
759}
760
761fn select_org_slug(mut orgs: Vec<crate::api::types::OrganizationInfo>) -> Result<String, ApiError> {
771 orgs.sort_by(|a, b| a.slug.cmp(&b.slug));
772 match orgs.len() {
773 0 => Err(ApiError::Other(
774 "No organizations found for this API token.".into(),
775 )),
776 1 => Ok(orgs.into_iter().next().unwrap().slug),
777 _ => {
778 let slugs: Vec<_> = orgs.iter().map(|o| o.slug.as_str()).collect();
779 let first = orgs[0].slug.clone();
780 eprintln!(
781 "Multiple organizations found: {}. Using \"{}\". \
782 Pass --org to select a different one.",
783 slugs.join(", "),
784 first
785 );
786 Ok(first)
787 }
788 }
789}
790
/// Percent-encodes `input` for use as a single URL path segment.
///
/// RFC 3986 unreserved bytes (`A-Z a-z 0-9 - _ . ~`) are kept. Every other
/// byte, including each byte of a multi-byte UTF-8 character, becomes `%XX`
/// with uppercase hex digits.
fn urlencoding_encode(input: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(input.len());
    for &b in input.as_bytes() {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(b));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(b & 0x0F)]));
        }
    }
    encoded
}
810
/// Shortens `s` to at most `max_chars` characters and appends `"..."` when
/// anything was cut. Strings that already fit are returned unchanged (no
/// ellipsis). Counting is by `char`, so multi-byte text is never split
/// mid-character.
fn truncate_to_chars(s: &str, max_chars: usize) -> String {
    // The byte offset of the (max_chars)-th char exists only if the string is too long.
    match s.char_indices().nth(max_chars) {
        None => s.to_string(),
        Some((cut, _)) => format!("{}...", &s[..cut]),
    }
}
820
/// Returns true when `s` is exactly 64 ASCII hex digits (either case), i.e. a
/// hex-encoded SHA-256 digest.
fn is_valid_sha256_hex(s: &str) -> bool {
    const SHA256_HEX_LEN: usize = 64;
    if s.len() != SHA256_HEX_LEN {
        return false;
    }
    s.as_bytes().iter().all(u8::is_ascii_hexdigit)
}
825
/// Returns true when `s` has canonical UUID shape: 36 bytes, hex digits
/// (either case) grouped 8-4-4-4-12, separated by `-`. Version/variant bits
/// are not checked.
fn is_valid_uuid(s: &str) -> bool {
    const UUID_LEN: usize = 36;
    const DASH_POSITIONS: [usize; 4] = [8, 13, 18, 23];

    let bytes = s.as_bytes();
    if bytes.len() != UUID_LEN {
        return false;
    }
    bytes.iter().enumerate().all(|(pos, &b)| {
        if DASH_POSITIONS.contains(&pos) {
            b == b'-'
        } else {
            b.is_ascii_hexdigit()
        }
    })
}
838
839fn convert_search_result_to_batch_info(patch: PatchSearchResult) -> BatchPatchInfo {
842 let mut cve_ids: Vec<String> = Vec::new();
843 let mut ghsa_ids: Vec<String> = Vec::new();
844 let mut highest_severity: Option<String> = None;
845 let mut title = String::new();
846
847 let mut seen_cves: HashSet<String> = HashSet::new();
848
849 let mut entries: Vec<(&String, &VulnerabilityResponse)> =
853 patch.vulnerabilities.iter().collect();
854 entries.sort_by(|a, b| a.0.cmp(b.0));
855
856 for (ghsa_id, vuln) in entries {
857 ghsa_ids.push(ghsa_id.clone());
858
859 for cve in &vuln.cves {
860 if seen_cves.insert(cve.clone()) {
861 cve_ids.push(cve.clone());
862 }
863 }
864
865 let current_order = get_severity_order(highest_severity.as_deref());
867 let vuln_order = get_severity_order(Some(&vuln.severity));
868 if vuln_order < current_order {
869 highest_severity = Some(vuln.severity.clone());
870 }
871
872 if title.is_empty() && !vuln.summary.is_empty() {
874 title = truncate_to_chars(&vuln.summary, 97);
875 }
876 }
877
878 if title.is_empty() && !patch.description.is_empty() {
880 title = truncate_to_chars(&patch.description, 97);
881 }
882
883 cve_ids.sort();
884 ghsa_ids.sort();
885
886 BatchPatchInfo {
887 uuid: patch.uuid,
888 purl: patch.purl,
889 tier: patch.tier,
890 cve_ids,
891 ghsa_ids,
892 severity: highest_severity,
893 title,
894 }
895}
896
897fn assemble_batch_from_individual(
909 results: Vec<(String, Option<SearchResponse>)>,
910) -> BatchSearchResponse {
911 let mut packages: Vec<BatchPackagePatches> = Vec::new();
912 let mut can_access_paid_patches = false;
913
914 for (purl, response) in results {
915 let Some(response) = response else { continue };
916
917 if response.can_access_paid_patches {
918 can_access_paid_patches = true;
919 }
920
921 if response.patches.is_empty() {
922 continue;
923 }
924
925 let batch_patches: Vec<BatchPatchInfo> = response
926 .patches
927 .into_iter()
928 .map(convert_search_result_to_batch_info)
929 .collect();
930
931 packages.push(BatchPackagePatches {
932 purl,
933 patches: batch_patches,
934 });
935 }
936
937 BatchSearchResponse {
938 packages,
939 can_access_paid_patches,
940 }
941}
942
/// Errors returned by [`ApiClient`] operations.
///
/// Each variant's `Display` is just its message string.
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
    /// A request could not be sent, or a response body could not be read.
    #[error("{0}")]
    Network(String),

    /// A 200 response body could not be deserialized as the expected JSON type.
    #[error("{0}")]
    Parse(String),

    /// The server answered HTTP 401.
    #[error("{0}")]
    Unauthorized(String),

    /// The server answered HTTP 403.
    #[error("{0}")]
    Forbidden(String),

    /// The server answered HTTP 429.
    #[error("{0}")]
    RateLimited(String),

    /// An identifier was rejected before any request was made: a blob hash that
    /// is not 64 hex characters, or a malformed patch UUID.
    #[error("{0}")]
    InvalidHash(String),

    /// Any other failure, e.g. an unexpected HTTP status or a token with no
    /// organizations.
    #[error("{0}")]
    Other(String),
}
969
970#[cfg(test)]
971mod tests {
972 use super::*;
973 use std::collections::HashMap;
974
975 #[test]
976 fn test_urlencoding_basic() {
977 assert_eq!(urlencoding_encode("hello"), "hello");
978 assert_eq!(urlencoding_encode("a b"), "a%20b");
979 assert_eq!(
980 urlencoding_encode("pkg:npm/lodash@4.17.21"),
981 "pkg%3Anpm%2Flodash%404.17.21"
982 );
983 }
984
985 #[test]
986 fn test_is_valid_sha256_hex() {
987 let valid = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
988 assert!(is_valid_sha256_hex(valid));
989
990 assert!(!is_valid_sha256_hex("abcdef"));
992 assert!(!is_valid_sha256_hex(
994 "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"
995 ));
996 }
997
998 #[test]
999 fn test_severity_order() {
1000 assert!(get_severity_order(Some("critical")) < get_severity_order(Some("high")));
1001 assert!(get_severity_order(Some("high")) < get_severity_order(Some("medium")));
1002 assert!(get_severity_order(Some("medium")) < get_severity_order(Some("low")));
1003 assert!(get_severity_order(Some("low")) < get_severity_order(None));
1004 assert_eq!(
1005 get_severity_order(Some("unknown")),
1006 get_severity_order(None)
1007 );
1008 }
1009
1010 #[test]
1011 fn test_convert_search_result_to_batch_info() {
1012 let mut vulns = HashMap::new();
1013 vulns.insert(
1014 "GHSA-1234-5678-9abc".to_string(),
1015 VulnerabilityResponse {
1016 cves: vec!["CVE-2024-0001".into()],
1017 summary: "Test vulnerability".into(),
1018 severity: "high".into(),
1019 description: "A test vuln".into(),
1020 },
1021 );
1022
1023 let patch = PatchSearchResult {
1024 uuid: "uuid-1".into(),
1025 purl: "pkg:npm/test@1.0.0".into(),
1026 published_at: "2024-01-01".into(),
1027 description: "A patch".into(),
1028 license: "MIT".into(),
1029 tier: "free".into(),
1030 vulnerabilities: vulns,
1031 };
1032
1033 let info = convert_search_result_to_batch_info(patch);
1034 assert_eq!(info.uuid, "uuid-1");
1035 assert_eq!(info.cve_ids, vec!["CVE-2024-0001"]);
1036 assert_eq!(info.ghsa_ids, vec!["GHSA-1234-5678-9abc"]);
1037 assert_eq!(info.severity, Some("high".into()));
1038 assert_eq!(info.title, "Test vulnerability");
1039 }
1040
1041 #[tokio::test]
1042 async fn test_get_api_client_from_env_no_token() {
1043 std::env::remove_var("SOCKET_API_TOKEN");
1045 let (client, is_public) = get_api_client_from_env(None).await;
1046 assert!(is_public);
1047 assert!(client.use_public_proxy);
1048 }
1049
1050 fn make_vuln(summary: &str, severity: &str, cves: Vec<&str>) -> VulnerabilityResponse {
1053 VulnerabilityResponse {
1054 cves: cves.into_iter().map(String::from).collect(),
1055 summary: summary.into(),
1056 severity: severity.into(),
1057 description: "desc".into(),
1058 }
1059 }
1060
1061 fn make_patch(
1062 vulns: HashMap<String, VulnerabilityResponse>,
1063 description: &str,
1064 ) -> PatchSearchResult {
1065 PatchSearchResult {
1066 uuid: "uuid-1".into(),
1067 purl: "pkg:npm/test@1.0.0".into(),
1068 published_at: "2024-01-01".into(),
1069 description: description.into(),
1070 license: "MIT".into(),
1071 tier: "free".into(),
1072 vulnerabilities: vulns,
1073 }
1074 }
1075
1076 #[test]
1077 fn test_convert_no_vulnerabilities() {
1078 let patch = make_patch(HashMap::new(), "A patch description");
1079 let info = convert_search_result_to_batch_info(patch);
1080 assert!(info.cve_ids.is_empty());
1081 assert!(info.ghsa_ids.is_empty());
1082 assert_eq!(info.title, "A patch description");
1083 assert!(info.severity.is_none());
1084 }
1085
1086 #[test]
1087 fn test_convert_multiple_vulns_picks_highest_severity() {
1088 let mut vulns = HashMap::new();
1089 vulns.insert(
1090 "GHSA-1111".into(),
1091 make_vuln("Medium vuln", "medium", vec!["CVE-2024-0001"]),
1092 );
1093 vulns.insert(
1094 "GHSA-2222".into(),
1095 make_vuln("Critical vuln", "critical", vec!["CVE-2024-0002"]),
1096 );
1097 let patch = make_patch(vulns, "desc");
1098 let info = convert_search_result_to_batch_info(patch);
1099 assert_eq!(info.severity, Some("critical".into()));
1100 }
1101
1102 #[test]
1103 fn test_convert_duplicate_cves_deduplicated() {
1104 let mut vulns = HashMap::new();
1105 vulns.insert(
1106 "GHSA-1111".into(),
1107 make_vuln("Vuln A", "high", vec!["CVE-2024-0001"]),
1108 );
1109 vulns.insert(
1110 "GHSA-2222".into(),
1111 make_vuln("Vuln B", "high", vec!["CVE-2024-0001"]),
1112 );
1113 let patch = make_patch(vulns, "desc");
1114 let info = convert_search_result_to_batch_info(patch);
1115 let cve_count = info
1117 .cve_ids
1118 .iter()
1119 .filter(|c| *c == "CVE-2024-0001")
1120 .count();
1121 assert_eq!(cve_count, 1);
1122 }
1123
1124 #[test]
1125 fn test_convert_title_truncated_at_100() {
1126 let long_summary = "x".repeat(150);
1127 let mut vulns = HashMap::new();
1128 vulns.insert("GHSA-1111".into(), make_vuln(&long_summary, "high", vec![]));
1129 let patch = make_patch(vulns, "desc");
1130 let info = convert_search_result_to_batch_info(patch);
1131 assert_eq!(info.title.len(), 100);
1133 assert!(info.title.ends_with("..."));
1134 }
1135
1136 #[test]
1137 fn test_convert_title_unicode_truncation() {
1138 let emoji_summary = "\u{1F600}".repeat(30);
1141 let mut vulns = HashMap::new();
1142 vulns.insert(
1143 "GHSA-1111".into(),
1144 make_vuln(&emoji_summary, "high", vec![]),
1145 );
1146 let patch = make_patch(vulns, "desc");
1147 let info = convert_search_result_to_batch_info(patch);
1149 assert!(!info.title.is_empty());
1150
1151 let patch2 = make_patch(HashMap::new(), &"\u{1F600}".repeat(120));
1153 let info2 = convert_search_result_to_batch_info(patch2);
1154 assert!(info2.title.ends_with("..."));
1155 }
1156
1157 #[test]
1158 fn test_convert_title_falls_back_to_description() {
1159 let mut vulns = HashMap::new();
1160 vulns.insert("GHSA-1111".into(), make_vuln("", "high", vec![]));
1161 let patch = make_patch(vulns, "Fallback desc");
1162 let info = convert_search_result_to_batch_info(patch);
1163 assert_eq!(info.title, "Fallback desc");
1164 }
1165
1166 #[test]
1167 fn test_convert_empty_summary_and_description() {
1168 let mut vulns = HashMap::new();
1169 vulns.insert("GHSA-1111".into(), make_vuln("", "high", vec![]));
1170 let patch = make_patch(vulns, "");
1171 let info = convert_search_result_to_batch_info(patch);
1172 assert!(info.title.is_empty());
1173 }
1174
1175 #[test]
1176 fn test_convert_cves_and_ghsas_sorted() {
1177 let mut vulns = HashMap::new();
1178 vulns.insert(
1179 "GHSA-cccc".into(),
1180 make_vuln("V1", "high", vec!["CVE-2024-0003"]),
1181 );
1182 vulns.insert(
1183 "GHSA-aaaa".into(),
1184 make_vuln("V2", "high", vec!["CVE-2024-0001"]),
1185 );
1186 vulns.insert(
1187 "GHSA-bbbb".into(),
1188 make_vuln("V3", "high", vec!["CVE-2024-0002"]),
1189 );
1190 let patch = make_patch(vulns, "desc");
1191 let info = convert_search_result_to_batch_info(patch);
1192 let mut sorted_cves = info.cve_ids.clone();
1194 sorted_cves.sort();
1195 assert_eq!(info.cve_ids, sorted_cves);
1196 let mut sorted_ghsas = info.ghsa_ids.clone();
1197 sorted_ghsas.sort();
1198 assert_eq!(info.ghsa_ids, sorted_ghsas);
1199 }
1200
1201 #[test]
1204 fn test_urlencoding_unicode() {
1205 let encoded = urlencoding_encode("café");
1207 assert_eq!(encoded, "caf%C3%A9");
1208 }
1209
1210 #[test]
1211 fn test_urlencoding_empty() {
1212 assert_eq!(urlencoding_encode(""), "");
1213 }
1214
1215 #[test]
1216 fn test_urlencoding_all_safe_chars() {
1217 let safe = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~";
1219 assert_eq!(urlencoding_encode(safe), safe);
1220 }
1221
1222 #[test]
1223 fn test_urlencoding_slash_and_at() {
1224 assert_eq!(urlencoding_encode("/"), "%2F");
1225 assert_eq!(urlencoding_encode("@"), "%40");
1226 }
1227
1228 #[test]
1229 fn test_sha256_uppercase_valid() {
1230 let upper = "ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789";
1231 assert!(is_valid_sha256_hex(upper));
1232 }
1233
1234 #[test]
1235 fn test_sha256_65_chars_invalid() {
1236 let too_long = "a".repeat(65);
1237 assert!(!is_valid_sha256_hex(&too_long));
1238 }
1239
1240 #[test]
1241 fn test_sha256_63_chars_invalid() {
1242 let too_short = "a".repeat(63);
1243 assert!(!is_valid_sha256_hex(&too_short));
1244 }
1245
1246 #[test]
1247 fn test_sha256_empty_invalid() {
1248 assert!(!is_valid_sha256_hex(""));
1249 }
1250
1251 #[test]
1252 fn test_sha256_mixed_case_valid() {
1253 let mixed = "aAbBcCdDeEfF0123456789aAbBcCdDeEfF0123456789aAbBcCdDeEfF01234567";
1254 assert_eq!(mixed.len(), 64);
1255 assert!(is_valid_sha256_hex(mixed));
1256 }
1257
1258 #[test]
1261 fn test_is_valid_uuid_accepts_standard_form() {
1262 assert!(is_valid_uuid("80630680-4da6-45f9-bba8-b888e0ffd58c"));
1263 assert!(is_valid_uuid("00000000-0000-0000-0000-000000000000"));
1264 assert!(is_valid_uuid("ABCDEF01-2345-6789-ABCD-EF0123456789"));
1266 }
1267
1268 #[test]
1269 fn test_is_valid_uuid_rejects_malformed() {
1270 assert!(!is_valid_uuid(""));
1271 assert!(!is_valid_uuid("not-a-uuid"));
1272 assert!(!is_valid_uuid("80630680-4da6-45f9-bba8"));
1274 assert!(!is_valid_uuid("8063068-4da6-45f9-bba8-b888e0ffd58c"));
1276 assert!(!is_valid_uuid("80630680-4da6-45f9-bba8-b888e0ffd58z"));
1278 assert!(!is_valid_uuid("80630680xxxxx"));
1280 }
1281
1282 #[tokio::test]
1290 async fn test_fetch_diff_rejects_invalid_uuid() {
1291 std::env::remove_var("SOCKET_API_TOKEN");
1292 let (client, _) = get_api_client_from_env(None).await;
1293 let result = client.fetch_diff("not-a-uuid").await;
1294 assert!(matches!(result, Err(ApiError::InvalidHash(_))));
1295 }
1296
1297 #[tokio::test]
1298 async fn test_fetch_package_rejects_invalid_uuid() {
1299 std::env::remove_var("SOCKET_API_TOKEN");
1300 let (client, _) = get_api_client_from_env(None).await;
1301 let result = client.fetch_package("xxx").await;
1302 assert!(matches!(result, Err(ApiError::InvalidHash(_))));
1303 }
1304
1305 #[test]
1308 fn validate_token_shape_accepts_canonical_api_token() {
1309 let raw = format!("sktsec_{}_api", "x".repeat(44));
1312 assert_eq!(raw.len(), 55);
1313 assert!(validate_token_shape(&raw).is_none());
1314 }
1315
1316 #[test]
1317 fn validate_token_shape_accepts_agent_token() {
1318 let raw = format!("sktsec_{}_agent", "x".repeat(44));
1319 assert!(validate_token_shape(&raw).is_none());
1320 }
1321
1322 #[test]
1323 fn validate_token_shape_flags_sha512_hash() {
1324 let hash = "sha512-7aegAloeNsCqF1mpNL2J9MJ2dpIxQEwgKvXPml8XY2rrV2Za+\
1325 bfj0yhG7RcqvqqLZ4iAH/drJjHjOqFkTGhddg==";
1326 let msg = validate_token_shape(hash).expect("hash must be flagged");
1327 assert!(
1328 msg.contains("does not look like a Socket API token"),
1329 "missing core warning; got: {msg}"
1330 );
1331 assert!(
1332 msg.contains("SRI-format hash"),
1333 "missing sha-hash hint; got: {msg}"
1334 );
1335 assert!(
1336 msg.contains("sktsec_"),
1337 "warning must point users at the correct prefix; got: {msg}"
1338 );
1339 assert!(
1341 !msg.contains("7RcqvqqLZ4iAH"),
1342 "middle of the value must be redacted; got: {msg}"
1343 );
1344 }
1345
1346 #[test]
1347 fn validate_token_shape_flags_too_short() {
1348 let msg = validate_token_shape("sktsec_abc_api").expect("short token must be flagged");
1349 assert!(msg.contains("does not look like a Socket API token"));
1350 assert!(!msg.contains("SRI-format hash"));
1351 }
1352
/// A `sktsec_` value without an `_api`/`_agent` suffix is flagged.
#[test]
fn validate_token_shape_flags_missing_suffix() {
    let mut token = String::from("sktsec_");
    token.push_str(&"x".repeat(50));
    assert!(validate_token_shape(&token).is_some());
}
1358
/// `sha256-`/`sha384-`/`sha512-` prefixed values are detected as hashes.
/// Tokens, arbitrary text and the empty string are not.
#[test]
fn looks_like_token_hash_recognizes_sri_prefixes() {
    for sri in ["sha256-abc", "sha384-abc", "sha512-abc"] {
        assert!(looks_like_token_hash(sri));
    }
    for other in ["sktsec_xxx_api", "hello", ""] {
        assert!(!looks_like_token_hash(other));
    }
}
1368
1369 fn proxy_client(api_url: &str) -> ApiClient {
1377 ApiClient::new(ApiClientOptions {
1378 api_url: api_url.into(),
1379 api_token: None,
1380 use_public_proxy: true,
1381 org_slug: None,
1382 })
1383 }
1384
/// In proxy mode, blob URLs are built from the configured API URL and do not
/// request authentication.
#[test]
fn binary_url_proxy_uses_configured_api_url() {
    let client = proxy_client("https://custom.proxy.example");
    let (blob_url, needs_auth) = client.binary_url("blob", "deadbeef");
    assert!(!needs_auth);
    assert_eq!(blob_url, "https://custom.proxy.example/patch/blob/deadbeef");
}
1392
/// Proxy-mode URLs for `diff` and `package` follow the same `/patch/<kind>/<id>` layout.
#[test]
fn binary_url_proxy_covers_diff_and_package() {
    let client = proxy_client("https://custom.proxy.example");
    for kind in ["diff", "package"] {
        let expected = format!("https://custom.proxy.example/patch/{kind}/uuid-1");
        assert_eq!(client.binary_url(kind, "uuid-1").0, expected);
    }
}
1405
/// A trailing slash on the configured URL must not produce a `//patch` path.
#[test]
fn binary_url_proxy_trims_trailing_slash() {
    let client = proxy_client("https://custom.proxy.example/");
    let (url, _) = client.binary_url("blob", "x");
    assert_eq!(url, "https://custom.proxy.example/patch/blob/x");
}
1416
/// Outside proxy mode, URLs are scoped to the org slug and request authentication.
#[test]
fn binary_url_authenticated_uses_org_path() {
    let options = ApiClientOptions {
        api_url: "https://api.socket.dev".to_string(),
        api_token: Some("sktsec_x_api".to_string()),
        use_public_proxy: false,
        org_slug: Some("my-org".to_string()),
    };
    let client = ApiClient::new(options);
    let (diff_url, needs_auth) = client.binary_url("diff", "uuid-123");
    assert!(needs_auth);
    assert_eq!(
        diff_url,
        "https://api.socket.dev/v0/orgs/my-org/patches/diff/uuid-123"
    );
}
1432
1433 fn org(slug: &str) -> crate::api::types::OrganizationInfo {
1436 crate::api::types::OrganizationInfo {
1437 id: format!("id-{slug}"),
1438 name: Some(slug.to_string()),
1439 image: None,
1440 plan: "free".into(),
1441 slug: slug.into(),
1442 }
1443 }
1444
/// With no organizations to choose from, selection fails with `ApiError::Other`.
#[test]
fn select_org_slug_errors_when_empty() {
    let outcome = select_org_slug(Vec::new());
    assert!(matches!(outcome, Err(ApiError::Other(_))));
}
1449
/// A single organization is selected as-is.
#[test]
fn select_org_slug_returns_sole_org() {
    let only = vec![org("acme")];
    let chosen = select_org_slug(only).unwrap();
    assert_eq!(chosen, "acme");
}
1454
/// With several organizations, the choice does not depend on input order.
/// "alpha" wins for both orderings below.
#[test]
fn select_org_slug_is_deterministic_for_multiple() {
    let orderings = [["zeta", "alpha", "mid"], ["mid", "zeta", "alpha"]];
    for slugs in orderings.iter() {
        let picked = select_org_slug(slugs.iter().map(|&s| org(s)).collect()).unwrap();
        assert_eq!(picked, "alpha");
    }
}
1464
1465 fn search_response(
1468 purl: &str,
1469 can_access_paid_patches: bool,
1470 patch_uuids: &[&str],
1471 ) -> SearchResponse {
1472 SearchResponse {
1473 patches: patch_uuids
1474 .iter()
1475 .map(|uuid| PatchSearchResult {
1476 uuid: (*uuid).into(),
1477 purl: purl.into(),
1478 published_at: "2024-01-01".into(),
1479 description: "desc".into(),
1480 license: "MIT".into(),
1481 tier: "free".into(),
1482 vulnerabilities: HashMap::new(),
1483 })
1484 .collect(),
1485 can_access_paid_patches,
1486 }
1487 }
1488
/// Each successful per-purl response becomes one package carrying all of its patches.
#[test]
fn assemble_batch_collects_patches_per_purl() {
    let results = vec![
        (
            String::from("pkg:npm/a@1"),
            Some(search_response("pkg:npm/a@1", false, &["uuid-a"])),
        ),
        (
            String::from("pkg:npm/b@1"),
            Some(search_response("pkg:npm/b@1", false, &["uuid-b1", "uuid-b2"])),
        ),
    ];
    let batch = assemble_batch_from_individual(results);
    assert_eq!(batch.packages.len(), 2);
    assert!(!batch.can_access_paid_patches);

    // Look packages up by purl rather than by index so the test does not
    // depend on output ordering.
    let patches_for = |purl: &str| {
        batch
            .packages
            .iter()
            .find(|p| p.purl == purl)
            .unwrap()
            .patches
            .len()
    };
    assert_eq!(patches_for("pkg:npm/a@1"), 1);
    assert_eq!(patches_for("pkg:npm/b@1"), 2);
}
1521
/// Lookups that errored (`None`) and responses with zero patches produce no package.
#[test]
fn assemble_batch_skips_errored_and_empty_responses() {
    let errored = (String::from("pkg:npm/err@1"), None);
    let empty = (
        String::from("pkg:npm/empty@1"),
        Some(search_response("pkg:npm/empty@1", false, &[])),
    );
    let ok = (
        String::from("pkg:npm/ok@1"),
        Some(search_response("pkg:npm/ok@1", false, &["uuid-ok"])),
    );
    let batch = assemble_batch_from_individual(vec![errored, empty, ok]);
    // Only the response that actually carried a patch survives.
    assert_eq!(batch.packages.len(), 1);
    assert_eq!(batch.packages[0].purl, "pkg:npm/ok@1");
}
1541
/// A single response granting paid access is enough for the batch flag to be set.
#[test]
fn assemble_batch_aggregates_paid_flag_across_all_responses() {
    let free = (
        String::from("pkg:npm/a@1"),
        Some(search_response("pkg:npm/a@1", false, &["uuid-a"])),
    );
    let paid = (
        String::from("pkg:npm/b@1"),
        Some(search_response("pkg:npm/b@1", true, &["uuid-b"])),
    );
    let batch = assemble_batch_from_individual(vec![free, paid]);
    assert!(batch.can_access_paid_patches);
}
1558
/// A response with no patches is not emitted as a package, but its
/// `can_access_paid_patches` flag must still reach the batch result.
#[test]
fn assemble_batch_keeps_paid_flag_from_empty_patch_response() {
    let free = (
        String::from("pkg:npm/free@1"),
        Some(search_response("pkg:npm/free@1", false, &["uuid-free"])),
    );
    let paid_only = (
        String::from("pkg:npm/paid-only@1"),
        Some(search_response("pkg:npm/paid-only@1", true, &[])),
    );
    let batch = assemble_batch_from_individual(vec![free, paid_only]);
    assert!(
        batch.can_access_paid_patches,
        "paid-access flag from an empty-patch response was dropped"
    );
    // The patch-less response still contributes no package entry.
    assert_eq!(batch.packages.len(), 1);
    assert_eq!(batch.packages[0].purl, "pkg:npm/free@1");
}
1584
/// The batch title must not depend on `HashMap` iteration order. With two
/// vulnerabilities of equal severity ("high"), the title must be "A summary"
/// every time.
#[test]
fn test_convert_title_deterministic_across_iteration_order() {
    // A single map exercises only one (randomly seeded) iteration order, so an
    // implementation that picked the title by iteration order would still pass
    // about half the time. Every `HashMap::new()` gets a fresh `RandomState`,
    // so building many maps, and alternating insertion order for good measure,
    // covers both orders with overwhelming probability.
    for round in 0..32 {
        let mut vulns = HashMap::new();
        if round % 2 == 0 {
            vulns.insert("GHSA-zzzz".into(), make_vuln("Z summary", "high", vec![]));
            vulns.insert("GHSA-aaaa".into(), make_vuln("A summary", "high", vec![]));
        } else {
            vulns.insert("GHSA-aaaa".into(), make_vuln("A summary", "high", vec![]));
            vulns.insert("GHSA-zzzz".into(), make_vuln("Z summary", "high", vec![]));
        }
        let patch = make_patch(vulns, "desc");
        let info = convert_search_result_to_batch_info(patch);
        assert_eq!(
            info.title, "A summary",
            "title depended on HashMap iteration order (round {round})"
        );
    }
}
1599}