1use std::collections::HashSet;
2
3use reqwest::header::{self, HeaderMap, HeaderValue};
4use reqwest::StatusCode;
5use serde::Serialize;
6
7use crate::api::types::*;
8use crate::constants::{
9 DEFAULT_PATCH_API_PROXY_URL, DEFAULT_SOCKET_API_URL, USER_AGENT as USER_AGENT_VALUE,
10};
11use crate::utils::env_compat::read_env_with_legacy;
12
13fn is_debug_enabled() -> bool {
16 match read_env_with_legacy("SOCKET_DEBUG", "SOCKET_PATCH_DEBUG") {
17 Some(val) => val == "1" || val == "true",
18 None => false,
19 }
20}
21
22fn debug_log(message: &str) {
24 if is_debug_enabled() {
25 eprintln!("[socket-patch debug] {}", message);
26 }
27}
28
/// Maps a severity label to a sort rank; a smaller rank means more severe.
///
/// Matching is case-insensitive: `critical` → 0, `high` → 1, `medium` → 2,
/// `low` → 3. Any other label, and `None`, rank last at 4.
fn get_severity_order(severity: Option<&str>) -> u8 {
    let lowered = match severity {
        Some(label) => label.to_lowercase(),
        None => return 4,
    };
    match lowered.as_str() {
        "critical" => 0,
        "high" => 1,
        "medium" => 2,
        "low" => 3,
        _ => 4,
    }
}
39
/// Configuration used to construct an `ApiClient`.
///
/// `Debug` is implemented by hand so that the API token is never rendered verbatim
/// (for example through `{:?}` in logs or panic messages); it prints `<redacted>`
/// when a token is present and `None` otherwise.
#[derive(Clone)]
pub struct ApiClientOptions {
    /// Base URL that request paths are appended to; `ApiClient::new` strips a
    /// trailing `/`.
    pub api_url: String,
    /// Optional API token; `ApiClient::new` sends it as `Authorization: Bearer <token>`.
    pub api_token: Option<String>,
    /// When `true`, JSON requests use the public proxy layout (`/patch/...`)
    /// instead of the org-scoped layout (`/v0/orgs/{slug}/...`).
    pub use_public_proxy: bool,
    /// Organization slug used by org-scoped endpoints.
    pub org_slug: Option<String>,
}

impl std::fmt::Debug for ApiClientOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiClientOptions")
            .field("api_url", &self.api_url)
            // Never print the credential itself; only whether one is set.
            .field("api_token", &self.api_token.as_ref().map(|_| "<redacted>"))
            .field("use_public_proxy", &self.use_public_proxy)
            .field("org_slug", &self.org_slug)
            .finish()
    }
}
52
53#[derive(Debug, Clone)]
59pub struct ApiClient {
60 client: reqwest::Client,
61 api_url: String,
62 api_token: Option<String>,
63 use_public_proxy: bool,
64 org_slug: Option<String>,
65}
66
/// JSON request body sent to the org-scoped batch search endpoint
/// (`POST /v0/orgs/{slug}/patches/batch`).
#[derive(Serialize)]
struct BatchSearchBody {
    /// One entry per package URL to look up.
    components: Vec<BatchComponent>,
}
72
/// A single package entry within `BatchSearchBody`; serialized by serde's derive
/// as an object with a `purl` field.
#[derive(Serialize)]
struct BatchComponent {
    /// Package URL (purl) whose patches are requested.
    purl: String,
}
77
78impl ApiClient {
79 pub fn new(options: ApiClientOptions) -> Self {
84 let api_url = options.api_url.trim_end_matches('/').to_string();
85
86 let mut default_headers = HeaderMap::new();
87 default_headers.insert(
88 header::USER_AGENT,
89 HeaderValue::from_static(USER_AGENT_VALUE),
90 );
91 default_headers.insert(
92 header::ACCEPT,
93 HeaderValue::from_static("application/json"),
94 );
95
96 if let Some(ref token) = options.api_token {
97 if let Ok(hv) = HeaderValue::from_str(&format!("Bearer {}", token)) {
98 default_headers.insert(header::AUTHORIZATION, hv);
99 }
100 }
101
102 let client = reqwest::Client::builder()
103 .default_headers(default_headers)
104 .build()
105 .expect("failed to build reqwest client");
106
107 Self {
108 client,
109 api_url,
110 api_token: options.api_token,
111 use_public_proxy: options.use_public_proxy,
112 org_slug: options.org_slug,
113 }
114 }
115
116 pub fn api_token(&self) -> Option<&String> {
118 self.api_token.as_ref()
119 }
120
121 pub fn org_slug(&self) -> Option<&String> {
123 self.org_slug.as_ref()
124 }
125
126 async fn get_json<T: serde::de::DeserializeOwned>(
130 &self,
131 path: &str,
132 ) -> Result<Option<T>, ApiError> {
133 let url = format!("{}{}", self.api_url, path);
134 debug_log(&format!("GET {}", url));
135
136 let resp = self
137 .client
138 .get(&url)
139 .send()
140 .await
141 .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
142
143 Self::handle_json_response(resp, self.use_public_proxy).await
144 }
145
146 async fn post_json<T: serde::de::DeserializeOwned, B: Serialize>(
148 &self,
149 path: &str,
150 body: &B,
151 ) -> Result<Option<T>, ApiError> {
152 let url = format!("{}{}", self.api_url, path);
153 debug_log(&format!("POST {}", url));
154
155 let resp = self
156 .client
157 .post(&url)
158 .header(header::CONTENT_TYPE, "application/json")
159 .json(body)
160 .send()
161 .await
162 .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
163
164 Self::handle_json_response(resp, self.use_public_proxy).await
165 }
166
167 async fn handle_json_response<T: serde::de::DeserializeOwned>(
169 resp: reqwest::Response,
170 use_public_proxy: bool,
171 ) -> Result<Option<T>, ApiError> {
172 let status = resp.status();
173
174 match status {
175 StatusCode::OK => {
176 let body = resp
177 .json::<T>()
178 .await
179 .map_err(|e| ApiError::Parse(format!("Failed to parse response: {}", e)))?;
180 Ok(Some(body))
181 }
182 StatusCode::NOT_FOUND => Ok(None),
183 StatusCode::UNAUTHORIZED => {
184 Err(ApiError::Unauthorized("Unauthorized: Invalid API token".into()))
185 }
186 StatusCode::FORBIDDEN => {
187 let msg = if use_public_proxy {
188 "Forbidden: This patch is only available to paid subscribers. \
189 Sign up at https://socket.dev to access paid patches."
190 } else {
191 "Forbidden: Access denied. This may be a paid patch or \
192 you may not have access to this organization."
193 };
194 Err(ApiError::Forbidden(msg.into()))
195 }
196 StatusCode::TOO_MANY_REQUESTS => {
197 Err(ApiError::RateLimited(
198 "Rate limit exceeded. Please try again later.".into(),
199 ))
200 }
201 _ => {
202 let text = resp.text().await.unwrap_or_default();
203 Err(ApiError::Other(format!(
204 "API request failed with status {}: {}",
205 status.as_u16(),
206 text
207 )))
208 }
209 }
210 }
211
212 pub async fn fetch_patch(
218 &self,
219 org_slug: Option<&str>,
220 uuid: &str,
221 ) -> Result<Option<PatchResponse>, ApiError> {
222 let path = if self.use_public_proxy {
223 format!("/patch/view/{}", uuid)
224 } else {
225 let slug = org_slug
226 .or(self.org_slug.as_deref())
227 .unwrap_or("default");
228 format!("/v0/orgs/{}/patches/view/{}", slug, uuid)
229 };
230 self.get_json(&path).await
231 }
232
233 pub async fn search_patches_by_cve(
235 &self,
236 org_slug: Option<&str>,
237 cve_id: &str,
238 ) -> Result<SearchResponse, ApiError> {
239 let encoded = urlencoding_encode(cve_id);
240 let path = if self.use_public_proxy {
241 format!("/patch/by-cve/{}", encoded)
242 } else {
243 let slug = org_slug
244 .or(self.org_slug.as_deref())
245 .unwrap_or("default");
246 format!("/v0/orgs/{}/patches/by-cve/{}", slug, encoded)
247 };
248 let result = self.get_json::<SearchResponse>(&path).await?;
249 Ok(result.unwrap_or_else(|| SearchResponse {
250 patches: Vec::new(),
251 can_access_paid_patches: false,
252 }))
253 }
254
255 pub async fn search_patches_by_ghsa(
257 &self,
258 org_slug: Option<&str>,
259 ghsa_id: &str,
260 ) -> Result<SearchResponse, ApiError> {
261 let encoded = urlencoding_encode(ghsa_id);
262 let path = if self.use_public_proxy {
263 format!("/patch/by-ghsa/{}", encoded)
264 } else {
265 let slug = org_slug
266 .or(self.org_slug.as_deref())
267 .unwrap_or("default");
268 format!("/v0/orgs/{}/patches/by-ghsa/{}", slug, encoded)
269 };
270 let result = self.get_json::<SearchResponse>(&path).await?;
271 Ok(result.unwrap_or_else(|| SearchResponse {
272 patches: Vec::new(),
273 can_access_paid_patches: false,
274 }))
275 }
276
277 pub async fn search_patches_by_package(
282 &self,
283 org_slug: Option<&str>,
284 purl: &str,
285 ) -> Result<SearchResponse, ApiError> {
286 let encoded = urlencoding_encode(purl);
287 let path = if self.use_public_proxy {
288 format!("/patch/by-package/{}", encoded)
289 } else {
290 let slug = org_slug
291 .or(self.org_slug.as_deref())
292 .unwrap_or("default");
293 format!("/v0/orgs/{}/patches/by-package/{}", slug, encoded)
294 };
295 let result = self.get_json::<SearchResponse>(&path).await?;
296 Ok(result.unwrap_or_else(|| SearchResponse {
297 patches: Vec::new(),
298 can_access_paid_patches: false,
299 }))
300 }
301
302 pub async fn search_patches_batch(
311 &self,
312 org_slug: Option<&str>,
313 purls: &[String],
314 ) -> Result<BatchSearchResponse, ApiError> {
315 if !self.use_public_proxy {
316 let slug = org_slug
317 .or(self.org_slug.as_deref())
318 .unwrap_or("default");
319 let path = format!("/v0/orgs/{}/patches/batch", slug);
320 let body = BatchSearchBody {
321 components: purls
322 .iter()
323 .map(|p| BatchComponent { purl: p.clone() })
324 .collect(),
325 };
326 let result = self.post_json::<BatchSearchResponse, _>(&path, &body).await?;
327 return Ok(result.unwrap_or_else(|| BatchSearchResponse {
328 packages: Vec::new(),
329 can_access_paid_patches: false,
330 }));
331 }
332
333 self.search_patches_batch_via_individual_queries(purls).await
335 }
336
337 async fn search_patches_batch_via_individual_queries(
343 &self,
344 purls: &[String],
345 ) -> Result<BatchSearchResponse, ApiError> {
346 const CONCURRENCY_LIMIT: usize = 10;
347
348 let mut packages: Vec<BatchPackagePatches> = Vec::new();
349 let mut can_access_paid_patches = false;
350
351 let mut all_results: Vec<(String, Option<SearchResponse>)> = Vec::new();
353
354 for chunk in purls.chunks(CONCURRENCY_LIMIT) {
355 let mut join_set = tokio::task::JoinSet::new();
357
358 for purl in chunk {
359 let purl = purl.clone();
360 let client = self.clone();
361 join_set.spawn(async move {
362 let resp = client.search_patches_by_package(None, &purl).await;
363 match resp {
364 Ok(r) => (purl, Some(r)),
365 Err(e) => {
366 debug_log(&format!("Error fetching patches for {}: {}", purl, e));
367 (purl, None)
368 }
369 }
370 });
371 }
372
373 while let Some(result) = join_set.join_next().await {
374 match result {
375 Ok(pair) => all_results.push(pair),
376 Err(e) => {
377 debug_log(&format!("Task join error: {}", e));
378 }
379 }
380 }
381 }
382
383 for (purl, response) in all_results {
385 let response = match response {
386 Some(r) if !r.patches.is_empty() => r,
387 _ => continue,
388 };
389
390 if response.can_access_paid_patches {
391 can_access_paid_patches = true;
392 }
393
394 let batch_patches: Vec<BatchPatchInfo> = response
395 .patches
396 .into_iter()
397 .map(convert_search_result_to_batch_info)
398 .collect();
399
400 packages.push(BatchPackagePatches {
401 purl,
402 patches: batch_patches,
403 });
404 }
405
406 Ok(BatchSearchResponse {
407 packages,
408 can_access_paid_patches,
409 })
410 }
411
412 pub async fn fetch_organizations(
414 &self,
415 ) -> Result<Vec<crate::api::types::OrganizationInfo>, ApiError> {
416 let path = "/v0/organizations";
417 match self
418 .get_json::<crate::api::types::OrganizationsResponse>(path)
419 .await?
420 {
421 Some(resp) => Ok(resp.organizations.into_values().collect()),
422 None => Ok(Vec::new()),
423 }
424 }
425
426 pub async fn resolve_org_slug(&self) -> Result<String, ApiError> {
432 let orgs = self.fetch_organizations().await?;
433 match orgs.len() {
434 0 => Err(ApiError::Other(
435 "No organizations found for this API token.".into(),
436 )),
437 1 => Ok(orgs.into_iter().next().unwrap().slug),
438 _ => {
439 let slugs: Vec<_> = orgs.iter().map(|o| o.slug.as_str()).collect();
440 let first = orgs[0].slug.clone();
441 eprintln!(
442 "Multiple organizations found: {}. Using \"{}\". \
443 Pass --org to select a different one.",
444 slugs.join(", "),
445 first
446 );
447 Ok(first)
448 }
449 }
450 }
451
452 pub async fn fetch_blob(&self, hash: &str) -> Result<Option<Vec<u8>>, ApiError> {
458 if !is_valid_sha256_hex(hash) {
460 return Err(ApiError::InvalidHash(format!(
461 "Invalid hash format: {}. Expected SHA256 hash (64 hex characters).",
462 hash
463 )));
464 }
465 self.fetch_binary("blob", "blob", hash).await
466 }
467
468 pub async fn fetch_diff(&self, uuid: &str) -> Result<Option<Vec<u8>>, ApiError> {
474 if !is_valid_uuid(uuid) {
475 return Err(ApiError::InvalidHash(format!(
476 "Invalid patch UUID: {}",
477 uuid
478 )));
479 }
480 self.fetch_binary("diff", "diff", uuid).await
481 }
482
483 pub async fn fetch_package(&self, uuid: &str) -> Result<Option<Vec<u8>>, ApiError> {
487 if !is_valid_uuid(uuid) {
488 return Err(ApiError::InvalidHash(format!(
489 "Invalid patch UUID: {}",
490 uuid
491 )));
492 }
493 self.fetch_binary("package", "package", uuid).await
494 }
495
496 async fn fetch_binary(
502 &self,
503 kind: &str,
504 label: &str,
505 identifier: &str,
506 ) -> Result<Option<Vec<u8>>, ApiError> {
507 let (url, use_auth) =
508 if self.api_token.is_some() && self.org_slug.is_some() && !self.use_public_proxy {
509 let slug = self.org_slug.as_deref().unwrap();
510 let u = format!(
511 "{}/v0/orgs/{}/patches/{}/{}",
512 self.api_url, slug, kind, identifier
513 );
514 (u, true)
515 } else {
516 let proxy_url =
517 read_env_with_legacy("SOCKET_PROXY_URL", "SOCKET_PATCH_PROXY_URL")
518 .unwrap_or_else(|| DEFAULT_PATCH_API_PROXY_URL.to_string());
519 let u = format!(
520 "{}/patch/{}/{}",
521 proxy_url.trim_end_matches('/'),
522 kind,
523 identifier
524 );
525 (u, false)
526 };
527
528 debug_log(&format!("GET {} {}", label, url));
529
530 let resp = if use_auth {
534 self.client
535 .get(&url)
536 .header(header::ACCEPT, "application/octet-stream")
537 .send()
538 .await
539 } else {
540 let mut headers = HeaderMap::new();
541 headers.insert(
542 header::USER_AGENT,
543 HeaderValue::from_static(USER_AGENT_VALUE),
544 );
545 headers.insert(
546 header::ACCEPT,
547 HeaderValue::from_static("application/octet-stream"),
548 );
549
550 let plain_client = reqwest::Client::builder()
551 .default_headers(headers)
552 .build()
553 .expect("failed to build plain reqwest client");
554
555 plain_client.get(&url).send().await
556 };
557
558 let resp = resp.map_err(|e| {
559 ApiError::Network(format!(
560 "Network error fetching {} {}: {}",
561 label, identifier, e
562 ))
563 })?;
564
565 let status = resp.status();
566
567 match status {
568 StatusCode::OK => {
569 let bytes = resp.bytes().await.map_err(|e| {
570 ApiError::Network(format!(
571 "Error reading {} body for {}: {}",
572 label, identifier, e
573 ))
574 })?;
575 Ok(Some(bytes.to_vec()))
576 }
577 StatusCode::NOT_FOUND => Ok(None),
578 _ => {
579 let text = resp.text().await.unwrap_or_default();
580 Err(ApiError::Other(format!(
581 "Failed to fetch {} {}: status {} - {}",
582 label,
583 identifier,
584 status.as_u16(),
585 text,
586 )))
587 }
588 }
589 }
590}
591
/// Explicit values that take precedence over environment variables when building a
/// client with `get_api_client_with_overrides`. A `None` field means "use the
/// environment variable, or the built-in default".
///
/// `Debug` is implemented by hand so that a token override is printed as
/// `<redacted>` rather than in plain text.
#[derive(Clone, Default)]
pub struct ApiClientEnvOverrides {
    /// Overrides `SOCKET_API_URL`.
    pub api_url: Option<String>,
    /// Overrides `SOCKET_API_TOKEN`.
    pub api_token: Option<String>,
    /// Overrides `SOCKET_ORG_SLUG`.
    pub org_slug: Option<String>,
    /// Overrides `SOCKET_PROXY_URL` (legacy `SOCKET_PATCH_PROXY_URL`).
    pub proxy_url: Option<String>,
}

impl std::fmt::Debug for ApiClientEnvOverrides {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiClientEnvOverrides")
            .field("api_url", &self.api_url)
            // Never print the credential itself; only whether one is set.
            .field("api_token", &self.api_token.as_ref().map(|_| "<redacted>"))
            .field("org_slug", &self.org_slug)
            .field("proxy_url", &self.proxy_url)
            .finish()
    }
}
606
607pub async fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) {
628 get_api_client_with_overrides(ApiClientEnvOverrides {
629 org_slug: org_slug.map(String::from),
630 ..ApiClientEnvOverrides::default()
631 })
632 .await
633}
634
635pub async fn get_api_client_with_overrides(
641 overrides: ApiClientEnvOverrides,
642) -> (ApiClient, bool) {
643 let api_token = overrides
644 .api_token
645 .or_else(|| std::env::var("SOCKET_API_TOKEN").ok())
646 .filter(|t| !t.is_empty());
647 let resolved_org_slug = overrides
648 .org_slug
649 .or_else(|| std::env::var("SOCKET_ORG_SLUG").ok());
650
651 if api_token.is_none() {
652 let proxy_url = overrides.proxy_url.unwrap_or_else(|| {
653 read_env_with_legacy("SOCKET_PROXY_URL", "SOCKET_PATCH_PROXY_URL")
654 .unwrap_or_else(|| DEFAULT_PATCH_API_PROXY_URL.to_string())
655 });
656 eprintln!(
657 "No SOCKET_API_TOKEN set. Using public patch API proxy (free patches only)."
658 );
659 let client = ApiClient::new(ApiClientOptions {
660 api_url: proxy_url,
661 api_token: None,
662 use_public_proxy: true,
663 org_slug: None,
664 });
665 return (client, true);
666 }
667
668 if let Some(ref t) = api_token {
671 if let Some(msg) = validate_token_shape(t) {
672 eprintln!("{msg}");
673 }
674 }
675
676 let api_url = overrides
677 .api_url
678 .or_else(|| std::env::var("SOCKET_API_URL").ok())
679 .unwrap_or_else(|| DEFAULT_SOCKET_API_URL.to_string());
680
681 let final_org_slug = if resolved_org_slug.is_some() {
683 resolved_org_slug
684 } else {
685 let temp_client = ApiClient::new(ApiClientOptions {
686 api_url: api_url.clone(),
687 api_token: api_token.clone(),
688 use_public_proxy: false,
689 org_slug: None,
690 });
691 match temp_client.resolve_org_slug().await {
692 Ok(slug) => Some(slug),
693 Err(e) => {
694 eprintln!("Warning: Could not auto-detect organization: {e}");
695 if matches!(e, ApiError::Unauthorized(_)) {
696 if let Some(ref t) = api_token {
697 if looks_like_token_hash(t) {
698 eprintln!(
699 " Hint: SOCKET_API_TOKEN starts with `{}-` \
700 which is the stored hash format. Set it to \
701 the raw `sktsec_..._api` value instead.",
702 t.split('-').next().unwrap_or("sha512")
703 );
704 }
705 }
706 }
707 None
708 }
709 }
710 };
711
712 let client = ApiClient::new(ApiClientOptions {
713 api_url,
714 api_token,
715 use_public_proxy: false,
716 org_slug: final_org_slug,
717 });
718 (client, false)
719}
720
721pub fn build_proxy_fallback_client(overrides: &ApiClientEnvOverrides) -> ApiClient {
729 let proxy_url = overrides.proxy_url.clone().unwrap_or_else(|| {
730 read_env_with_legacy("SOCKET_PROXY_URL", "SOCKET_PATCH_PROXY_URL")
731 .unwrap_or_else(|| DEFAULT_PATCH_API_PROXY_URL.to_string())
732 });
733 ApiClient::new(ApiClientOptions {
734 api_url: proxy_url,
735 api_token: None,
736 use_public_proxy: true,
737 org_slug: None,
738 })
739}
740
/// Returns `true` when `token` begins with an SRI-style digest prefix (`sha256-`,
/// `sha384-`, or `sha512-`). Such a value looks like a stored token hash rather
/// than the raw token.
pub fn looks_like_token_hash(token: &str) -> bool {
    const SRI_PREFIXES: [&str; 3] = ["sha256-", "sha384-", "sha512-"];
    SRI_PREFIXES.iter().any(|&prefix| token.starts_with(prefix))
}
753
754pub fn validate_token_shape(token: &str) -> Option<String> {
770 let has_prefix = token.starts_with("sktsec_");
771 let has_suffix = token.ends_with("_api") || token.ends_with("_agent");
772 let plausible_len = token.len() >= 55;
773 if has_prefix && has_suffix && plausible_len {
774 return None;
775 }
776 let len = token.len();
777 let head: String = token.chars().take(8).collect();
778 let tail_start = len.saturating_sub(4);
779 let tail: String = token.chars().skip(tail_start).collect();
780 let preview = if len <= 12 {
781 token.to_string()
782 } else {
783 format!("{head}...{tail}")
784 };
785 let hash_hint = if looks_like_token_hash(token) {
786 "\n That value looks like an SRI-format hash (sha###-<base64>) — \
787 the server stores the *hash* of your token, not what you should \
788 set here. Use the raw `sktsec_..._api` value shown when the token \
789 was generated."
790 } else {
791 ""
792 };
793 Some(format!(
794 "Warning: SOCKET_API_TOKEN does not look like a Socket API token \
795 (expected `sktsec_<44 chars>_api`).{hash_hint}\n \
796 Got: {preview} ({len} chars). Continuing anyway; the server may \
797 reject this with 401."
798 ))
799}
800
801pub fn is_fallback_candidate(err: &ApiError) -> bool {
806 matches!(err, ApiError::Unauthorized(_) | ApiError::Forbidden(_))
807}
808
/// Percent-encodes `input` so that it can be used as a single URL path segment.
///
/// Only RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through. Every
/// other byte of the UTF-8 encoding becomes `%XX` with uppercase hex digits, so `/`,
/// `:` and `@` inside a purl are encoded too.
fn urlencoding_encode(input: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    input
        .bytes()
        .fold(String::with_capacity(input.len()), |mut encoded, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            encoded
        })
}
828
/// Shortens `s` to at most `max_chars` Unicode scalar values and appends `...` if
/// anything was cut. A string that already fits is returned unchanged, without an
/// ellipsis. The cut is made on a char boundary, so multi-byte characters are never
/// split.
fn truncate_to_chars(s: &str, max_chars: usize) -> String {
    // The byte offset of the first char that does not fit, if there is one.
    match s.char_indices().nth(max_chars) {
        None => s.to_string(),
        Some((cut, _)) => format!("{}...", &s[..cut]),
    }
}
838
/// Returns `true` if `s` consists of exactly 64 ASCII hex digits (either case), the
/// hex encoding of a SHA-256 digest.
fn is_valid_sha256_hex(s: &str) -> bool {
    const SHA256_HEX_LEN: usize = 64;
    let bytes = s.as_bytes();
    bytes.len() == SHA256_HEX_LEN && !bytes.iter().any(|b| !b.is_ascii_hexdigit())
}
843
/// Returns `true` if `s` is a canonical hyphenated UUID: groups of 8-4-4-4-12 ASCII
/// hex digits (either case) separated by `-`. Version and variant bits are not
/// checked.
fn is_valid_uuid(s: &str) -> bool {
    // Byte offsets of the four separators in the 36-byte canonical form.
    const HYPHEN_POSITIONS: [usize; 4] = [8, 13, 18, 23];
    let bytes = s.as_bytes();
    if bytes.len() != 36 {
        return false;
    }
    bytes.iter().enumerate().all(|(i, b)| {
        if HYPHEN_POSITIONS.contains(&i) {
            *b == b'-'
        } else {
            b.is_ascii_hexdigit()
        }
    })
}
856
857fn convert_search_result_to_batch_info(patch: PatchSearchResult) -> BatchPatchInfo {
860 let mut cve_ids: Vec<String> = Vec::new();
861 let mut ghsa_ids: Vec<String> = Vec::new();
862 let mut highest_severity: Option<String> = None;
863 let mut title = String::new();
864
865 let mut seen_cves: HashSet<String> = HashSet::new();
866
867 for (ghsa_id, vuln) in &patch.vulnerabilities {
868 ghsa_ids.push(ghsa_id.clone());
869
870 for cve in &vuln.cves {
871 if seen_cves.insert(cve.clone()) {
872 cve_ids.push(cve.clone());
873 }
874 }
875
876 let current_order = get_severity_order(highest_severity.as_deref());
878 let vuln_order = get_severity_order(Some(&vuln.severity));
879 if vuln_order < current_order {
880 highest_severity = Some(vuln.severity.clone());
881 }
882
883 if title.is_empty() && !vuln.summary.is_empty() {
885 title = truncate_to_chars(&vuln.summary, 97);
886 }
887 }
888
889 if title.is_empty() && !patch.description.is_empty() {
891 title = truncate_to_chars(&patch.description, 97);
892 }
893
894 cve_ids.sort();
895 ghsa_ids.sort();
896
897 BatchPatchInfo {
898 uuid: patch.uuid,
899 purl: patch.purl,
900 tier: patch.tier,
901 cve_ids,
902 ghsa_ids,
903 severity: highest_severity,
904 title,
905 }
906}
907
/// Errors returned by `ApiClient` operations.
///
/// Every variant carries a human-readable message, which is also its `Display`
/// output (`#[error("{0}")]`).
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
    /// Sending the request failed, or reading a binary response body failed.
    #[error("{0}")]
    Network(String),

    /// A 200 JSON response body could not be deserialized.
    #[error("{0}")]
    Parse(String),

    /// The server answered 401 Unauthorized.
    #[error("{0}")]
    Unauthorized(String),

    /// The server answered 403 Forbidden (e.g. a paid patch or an organization the
    /// token cannot access).
    #[error("{0}")]
    Forbidden(String),

    /// The server answered 429 Too Many Requests.
    #[error("{0}")]
    RateLimited(String),

    /// An identifier failed local validation before any request was sent: a
    /// malformed SHA-256 blob hash, or a malformed patch UUID.
    #[error("{0}")]
    InvalidHash(String),

    /// Any other failure, e.g. an unexpected HTTP status or no organizations found
    /// for the token.
    #[error("{0}")]
    Other(String),
}
934
// Unit tests for the pure helpers (URL encoding, hash/UUID validation, severity
// ranking, search-result conversion, token-shape checks) and for client code paths
// that return before any network request is made.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_urlencoding_basic() {
        assert_eq!(urlencoding_encode("hello"), "hello");
        assert_eq!(urlencoding_encode("a b"), "a%20b");
        // `:`, `/` and `@` are all reserved, so a purl is fully escaped.
        assert_eq!(
            urlencoding_encode("pkg:npm/lodash@4.17.21"),
            "pkg%3Anpm%2Flodash%404.17.21"
        );
    }

    #[test]
    fn test_is_valid_sha256_hex() {
        let valid = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
        assert!(is_valid_sha256_hex(valid));

        // Too short.
        assert!(!is_valid_sha256_hex("abcdef"));
        // Non-hex characters.
        assert!(!is_valid_sha256_hex(
            "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"
        ));
    }

    #[test]
    fn test_severity_order() {
        assert!(get_severity_order(Some("critical")) < get_severity_order(Some("high")));
        assert!(get_severity_order(Some("high")) < get_severity_order(Some("medium")));
        assert!(get_severity_order(Some("medium")) < get_severity_order(Some("low")));
        assert!(get_severity_order(Some("low")) < get_severity_order(None));
        // Unrecognized labels rank the same as a missing severity.
        assert_eq!(get_severity_order(Some("unknown")), get_severity_order(None));
    }

    #[test]
    fn test_convert_search_result_to_batch_info() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-1234-5678-9abc".to_string(),
            VulnerabilityResponse {
                cves: vec!["CVE-2024-0001".into()],
                summary: "Test vulnerability".into(),
                severity: "high".into(),
                description: "A test vuln".into(),
            },
        );

        let patch = PatchSearchResult {
            uuid: "uuid-1".into(),
            purl: "pkg:npm/test@1.0.0".into(),
            published_at: "2024-01-01".into(),
            description: "A patch".into(),
            license: "MIT".into(),
            tier: "free".into(),
            vulnerabilities: vulns,
        };

        let info = convert_search_result_to_batch_info(patch);
        assert_eq!(info.uuid, "uuid-1");
        assert_eq!(info.cve_ids, vec!["CVE-2024-0001"]);
        assert_eq!(info.ghsa_ids, vec!["GHSA-1234-5678-9abc"]);
        assert_eq!(info.severity, Some("high".into()));
        // The vulnerability summary takes precedence over the patch description.
        assert_eq!(info.title, "Test vulnerability");
    }

    #[tokio::test]
    async fn test_get_api_client_from_env_no_token() {
        // NOTE(review): this mutates a process-wide env var while cargo runs tests on
        // parallel threads; any test that sets SOCKET_API_TOKEN concurrently could
        // interfere.
        std::env::remove_var("SOCKET_API_TOKEN");
        let (client, is_public) = get_api_client_from_env(None).await;
        assert!(is_public);
        assert!(client.use_public_proxy);
    }

    // Builds a VulnerabilityResponse with a fixed description.
    fn make_vuln(summary: &str, severity: &str, cves: Vec<&str>) -> VulnerabilityResponse {
        VulnerabilityResponse {
            cves: cves.into_iter().map(String::from).collect(),
            summary: summary.into(),
            severity: severity.into(),
            description: "desc".into(),
        }
    }

    // Builds a PatchSearchResult with fixed metadata around the given
    // vulnerabilities and description.
    fn make_patch(
        vulns: HashMap<String, VulnerabilityResponse>,
        description: &str,
    ) -> PatchSearchResult {
        PatchSearchResult {
            uuid: "uuid-1".into(),
            purl: "pkg:npm/test@1.0.0".into(),
            published_at: "2024-01-01".into(),
            description: description.into(),
            license: "MIT".into(),
            tier: "free".into(),
            vulnerabilities: vulns,
        }
    }

    #[test]
    fn test_convert_no_vulnerabilities() {
        let patch = make_patch(HashMap::new(), "A patch description");
        let info = convert_search_result_to_batch_info(patch);
        assert!(info.cve_ids.is_empty());
        assert!(info.ghsa_ids.is_empty());
        assert_eq!(info.title, "A patch description");
        assert!(info.severity.is_none());
    }

    #[test]
    fn test_convert_multiple_vulns_picks_highest_severity() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-1111".into(),
            make_vuln("Medium vuln", "medium", vec!["CVE-2024-0001"]),
        );
        vulns.insert(
            "GHSA-2222".into(),
            make_vuln("Critical vuln", "critical", vec!["CVE-2024-0002"]),
        );
        let patch = make_patch(vulns, "desc");
        let info = convert_search_result_to_batch_info(patch);
        assert_eq!(info.severity, Some("critical".into()));
    }

    #[test]
    fn test_convert_duplicate_cves_deduplicated() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-1111".into(),
            make_vuln("Vuln A", "high", vec!["CVE-2024-0001"]),
        );
        vulns.insert(
            "GHSA-2222".into(),
            make_vuln("Vuln B", "high", vec!["CVE-2024-0001"]),
        );
        let patch = make_patch(vulns, "desc");
        let info = convert_search_result_to_batch_info(patch);
        // The same CVE under two advisories must appear only once.
        let cve_count = info.cve_ids.iter().filter(|c| *c == "CVE-2024-0001").count();
        assert_eq!(cve_count, 1);
    }

    #[test]
    fn test_convert_title_truncated_at_100() {
        let long_summary = "x".repeat(150);
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-1111".into(),
            make_vuln(&long_summary, "high", vec![]),
        );
        let patch = make_patch(vulns, "desc");
        let info = convert_search_result_to_batch_info(patch);
        // 97 characters of summary plus the three-character "..." suffix.
        assert_eq!(info.title.len(), 100);
        assert!(info.title.ends_with("..."));
    }

    #[test]
    fn test_convert_title_unicode_truncation() {
        // 30 four-byte emoji: under the 97-char limit, so the summary is kept whole
        // (a byte-based cut would split a character).
        let emoji_summary = "\u{1F600}".repeat(30);
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-1111".into(),
            make_vuln(&emoji_summary, "high", vec![]),
        );
        let patch = make_patch(vulns, "desc");
        let info = convert_search_result_to_batch_info(patch);
        assert!(!info.title.is_empty());

        // 120 emoji in the description exceed the limit and must be cut on a char
        // boundary.
        let patch2 = make_patch(HashMap::new(), &"\u{1F600}".repeat(120));
        let info2 = convert_search_result_to_batch_info(patch2);
        assert!(info2.title.ends_with("..."));
    }

    #[test]
    fn test_convert_title_falls_back_to_description() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-1111".into(),
            make_vuln("", "high", vec![]),
        );
        let patch = make_patch(vulns, "Fallback desc");
        let info = convert_search_result_to_batch_info(patch);
        assert_eq!(info.title, "Fallback desc");
    }

    #[test]
    fn test_convert_empty_summary_and_description() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-1111".into(),
            make_vuln("", "high", vec![]),
        );
        let patch = make_patch(vulns, "");
        let info = convert_search_result_to_batch_info(patch);
        assert!(info.title.is_empty());
    }

    #[test]
    fn test_convert_cves_and_ghsas_sorted() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-cccc".into(),
            make_vuln("V1", "high", vec!["CVE-2024-0003"]),
        );
        vulns.insert(
            "GHSA-aaaa".into(),
            make_vuln("V2", "high", vec!["CVE-2024-0001"]),
        );
        vulns.insert(
            "GHSA-bbbb".into(),
            make_vuln("V3", "high", vec!["CVE-2024-0002"]),
        );
        let patch = make_patch(vulns, "desc");
        let info = convert_search_result_to_batch_info(patch);
        // Both id lists must come back sorted regardless of map iteration order.
        let mut sorted_cves = info.cve_ids.clone();
        sorted_cves.sort();
        assert_eq!(info.cve_ids, sorted_cves);
        let mut sorted_ghsas = info.ghsa_ids.clone();
        sorted_ghsas.sort();
        assert_eq!(info.ghsa_ids, sorted_ghsas);
    }

    #[test]
    fn test_urlencoding_unicode() {
        // Non-ASCII characters are encoded byte-by-byte from their UTF-8 form.
        let encoded = urlencoding_encode("café");
        assert_eq!(encoded, "caf%C3%A9");
    }

    #[test]
    fn test_urlencoding_empty() {
        assert_eq!(urlencoding_encode(""), "");
    }

    #[test]
    fn test_urlencoding_all_safe_chars() {
        // RFC 3986 unreserved characters pass through unchanged.
        let safe = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~";
        assert_eq!(urlencoding_encode(safe), safe);
    }

    #[test]
    fn test_urlencoding_slash_and_at() {
        assert_eq!(urlencoding_encode("/"), "%2F");
        assert_eq!(urlencoding_encode("@"), "%40");
    }

    #[test]
    fn test_sha256_uppercase_valid() {
        let upper = "ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789";
        assert!(is_valid_sha256_hex(upper));
    }

    #[test]
    fn test_sha256_65_chars_invalid() {
        let too_long = "a".repeat(65);
        assert!(!is_valid_sha256_hex(&too_long));
    }

    #[test]
    fn test_sha256_63_chars_invalid() {
        let too_short = "a".repeat(63);
        assert!(!is_valid_sha256_hex(&too_short));
    }

    #[test]
    fn test_sha256_empty_invalid() {
        assert!(!is_valid_sha256_hex(""));
    }

    #[test]
    fn test_sha256_mixed_case_valid() {
        let mixed = "aAbBcCdDeEfF0123456789aAbBcCdDeEfF0123456789aAbBcCdDeEfF01234567";
        assert_eq!(mixed.len(), 64);
        assert!(is_valid_sha256_hex(mixed));
    }

    #[test]
    fn test_is_valid_uuid_accepts_standard_form() {
        assert!(is_valid_uuid("80630680-4da6-45f9-bba8-b888e0ffd58c"));
        assert!(is_valid_uuid("00000000-0000-0000-0000-000000000000"));
        // Hex digits are accepted in either case.
        assert!(is_valid_uuid("ABCDEF01-2345-6789-ABCD-EF0123456789"));
    }

    #[test]
    fn test_is_valid_uuid_rejects_malformed() {
        assert!(!is_valid_uuid(""));
        assert!(!is_valid_uuid("not-a-uuid"));
        // Missing the final group.
        assert!(!is_valid_uuid("80630680-4da6-45f9-bba8"));
        // First group one digit short.
        assert!(!is_valid_uuid("8063068-4da6-45f9-bba8-b888e0ffd58c"));
        // Non-hex character in the last group.
        assert!(!is_valid_uuid("80630680-4da6-45f9-bba8-b888e0ffd58z"));
        // No separators at all.
        assert!(!is_valid_uuid("80630680xxxxx"));
    }

    // Invalid identifiers must be rejected locally, before any request is built.
    #[tokio::test]
    async fn test_fetch_diff_rejects_invalid_uuid() {
        std::env::remove_var("SOCKET_API_TOKEN");
        let (client, _) = get_api_client_from_env(None).await;
        let result = client.fetch_diff("not-a-uuid").await;
        assert!(matches!(result, Err(ApiError::InvalidHash(_))));
    }

    #[tokio::test]
    async fn test_fetch_package_rejects_invalid_uuid() {
        std::env::remove_var("SOCKET_API_TOKEN");
        let (client, _) = get_api_client_from_env(None).await;
        let result = client.fetch_package("xxx").await;
        assert!(matches!(result, Err(ApiError::InvalidHash(_))));
    }

    #[test]
    fn validate_token_shape_accepts_canonical_api_token() {
        // `sktsec_` (7) + 44 body chars + `_api` (4) = 55, the minimum plausible length.
        let raw = format!("sktsec_{}_api", "x".repeat(44));
        assert_eq!(raw.len(), 55);
        assert!(validate_token_shape(&raw).is_none());
    }

    #[test]
    fn validate_token_shape_accepts_agent_token() {
        let raw = format!("sktsec_{}_agent", "x".repeat(44));
        assert!(validate_token_shape(&raw).is_none());
    }

    #[test]
    fn validate_token_shape_flags_sha512_hash() {
        let hash = "sha512-7aegAloeNsCqF1mpNL2J9MJ2dpIxQEwgKvXPml8XY2rrV2Za+\
                    bfj0yhG7RcqvqqLZ4iAH/drJjHjOqFkTGhddg==";
        let msg = validate_token_shape(hash).expect("hash must be flagged");
        assert!(
            msg.contains("does not look like a Socket API token"),
            "missing core warning; got: {msg}"
        );
        assert!(
            msg.contains("SRI-format hash"),
            "missing sha-hash hint; got: {msg}"
        );
        assert!(
            msg.contains("sktsec_"),
            "warning must point users at the correct prefix; got: {msg}"
        );
        // Only the head and tail of the value may appear in the preview.
        assert!(
            !msg.contains("7RcqvqqLZ4iAH"),
            "middle of the value must be redacted; got: {msg}"
        );
    }

    #[test]
    fn validate_token_shape_flags_too_short() {
        let msg = validate_token_shape("sktsec_abc_api")
            .expect("short token must be flagged");
        assert!(msg.contains("does not look like a Socket API token"));
        assert!(!msg.contains("SRI-format hash"));
    }

    #[test]
    fn validate_token_shape_flags_missing_suffix() {
        let raw = format!("sktsec_{}", "x".repeat(50));
        assert!(validate_token_shape(&raw).is_some());
    }

    #[test]
    fn looks_like_token_hash_recognizes_sri_prefixes() {
        assert!(looks_like_token_hash("sha256-abc"));
        assert!(looks_like_token_hash("sha384-abc"));
        assert!(looks_like_token_hash("sha512-abc"));
        assert!(!looks_like_token_hash("sktsec_xxx_api"));
        assert!(!looks_like_token_hash("hello"));
        assert!(!looks_like_token_hash(""));
    }
}