Skip to main content

socket_patch_core/api/
client.rs

1use std::collections::HashSet;
2
3use reqwest::header::{self, HeaderMap, HeaderValue};
4use reqwest::StatusCode;
5use serde::Serialize;
6
7use crate::api::types::*;
8use crate::constants::{
9    DEFAULT_PATCH_API_PROXY_URL, DEFAULT_SOCKET_API_URL, USER_AGENT as USER_AGENT_VALUE,
10};
11
/// Report whether verbose diagnostics were requested through the environment.
///
/// Returns `true` only when `SOCKET_PATCH_DEBUG` is set to exactly `1` or
/// `true`. An unset or unreadable variable, or any other value, yields `false`.
fn is_debug_enabled() -> bool {
    std::env::var("SOCKET_PATCH_DEBUG")
        .map(|flag| matches!(flag.as_str(), "1" | "true"))
        .unwrap_or(false)
}
19
20/// Log debug messages when debug mode is enabled.
21fn debug_log(message: &str) {
22    if is_debug_enabled() {
23        eprintln!("[socket-patch debug] {}", message);
24    }
25}
26
/// Rank a severity label for sorting; a smaller number means more severe.
///
/// Matching is case-insensitive. `critical`, `high`, `medium` and `low` map
/// to 0 through 3; a missing or unrecognised label maps to 4.
fn get_severity_order(severity: Option<&str>) -> u8 {
    const UNRANKED: u8 = 4;

    let label = match severity {
        Some(raw) => raw.to_lowercase(),
        None => return UNRANKED,
    };

    match label.as_str() {
        "critical" => 0,
        "high" => 1,
        "medium" => 2,
        "low" => 3,
        _ => UNRANKED,
    }
}
37
/// Options for constructing an [`ApiClient`].
///
/// NOTE(review): the derived `Debug` output includes `api_token` verbatim;
/// avoid logging this struct where secrets must not appear.
#[derive(Debug, Clone)]
pub struct ApiClientOptions {
    /// Base URL that request paths are appended to. `ApiClient::new` strips
    /// any trailing `/` before storing it.
    pub api_url: String,
    /// Token sent as an `Authorization: Bearer …` default header when present.
    pub api_token: Option<String>,
    /// When true, the client will use the public patch API proxy
    /// which only provides access to free patches without authentication.
    pub use_public_proxy: bool,
    /// Organization slug for authenticated API access.
    /// Required when using authenticated API (not public proxy).
    pub org_slug: Option<String>,
}
50
/// HTTP client for the Socket Patch API.
///
/// Supports both the authenticated Socket API (`api.socket.dev`) and the
/// public proxy (`patches-api.socket.dev`) which serves free patches
/// without authentication.
///
/// NOTE(review): the derived `Debug` output includes `api_token` verbatim.
#[derive(Debug, Clone)]
pub struct ApiClient {
    /// Underlying HTTP client carrying the default headers set in `new`.
    client: reqwest::Client,
    /// Base URL with trailing slashes removed; request paths are appended to it.
    api_url: String,
    /// API token as supplied in the options, if any.
    api_token: Option<String>,
    /// Selects the public proxy routes (`/patch/...`) instead of the
    /// organization-scoped routes (`/v0/orgs/{slug}/patches/...`).
    use_public_proxy: bool,
    /// Default organization slug used when a method is not given one explicitly.
    org_slug: Option<String>,
}
64
/// Body payload for the batch search POST endpoint.
///
/// Serialized as JSON by `search_patches_batch`.
#[derive(Serialize)]
struct BatchSearchBody {
    /// One entry per PURL being queried.
    components: Vec<BatchComponent>,
}
70
/// A single package entry inside a [`BatchSearchBody`].
#[derive(Serialize)]
struct BatchComponent {
    /// Package URL of the component to look up.
    purl: String,
}
75
76impl ApiClient {
77    /// Create a new API client from the given options.
78    ///
79    /// Constructs a `reqwest::Client` with proper default headers
80    /// (User-Agent, Accept, and optionally Authorization).
81    pub fn new(options: ApiClientOptions) -> Self {
82        let api_url = options.api_url.trim_end_matches('/').to_string();
83
84        let mut default_headers = HeaderMap::new();
85        default_headers.insert(
86            header::USER_AGENT,
87            HeaderValue::from_static(USER_AGENT_VALUE),
88        );
89        default_headers.insert(
90            header::ACCEPT,
91            HeaderValue::from_static("application/json"),
92        );
93
94        if let Some(ref token) = options.api_token {
95            if let Ok(hv) = HeaderValue::from_str(&format!("Bearer {}", token)) {
96                default_headers.insert(header::AUTHORIZATION, hv);
97            }
98        }
99
100        let client = reqwest::Client::builder()
101            .default_headers(default_headers)
102            .build()
103            .expect("failed to build reqwest client");
104
105        Self {
106            client,
107            api_url,
108            api_token: options.api_token,
109            use_public_proxy: options.use_public_proxy,
110            org_slug: options.org_slug,
111        }
112    }
113
114    /// Returns the API token, if set.
115    pub fn api_token(&self) -> Option<&String> {
116        self.api_token.as_ref()
117    }
118
119    /// Returns the org slug, if set.
120    pub fn org_slug(&self) -> Option<&String> {
121        self.org_slug.as_ref()
122    }
123
124    // ── Internal helpers ──────────────────────────────────────────────
125
126    /// Internal GET that deserialises JSON. Returns `Ok(None)` on 404.
127    async fn get_json<T: serde::de::DeserializeOwned>(
128        &self,
129        path: &str,
130    ) -> Result<Option<T>, ApiError> {
131        let url = format!("{}{}", self.api_url, path);
132        debug_log(&format!("GET {}", url));
133
134        let resp = self
135            .client
136            .get(&url)
137            .send()
138            .await
139            .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
140
141        Self::handle_json_response(resp, self.use_public_proxy).await
142    }
143
144    /// Internal POST that deserialises JSON. Returns `Ok(None)` on 404.
145    async fn post_json<T: serde::de::DeserializeOwned, B: Serialize>(
146        &self,
147        path: &str,
148        body: &B,
149    ) -> Result<Option<T>, ApiError> {
150        let url = format!("{}{}", self.api_url, path);
151        debug_log(&format!("POST {}", url));
152
153        let resp = self
154            .client
155            .post(&url)
156            .header(header::CONTENT_TYPE, "application/json")
157            .json(body)
158            .send()
159            .await
160            .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
161
162        Self::handle_json_response(resp, self.use_public_proxy).await
163    }
164
165    /// Map an HTTP response to `Ok(Some(T))`, `Ok(None)` (404), or `Err`.
166    async fn handle_json_response<T: serde::de::DeserializeOwned>(
167        resp: reqwest::Response,
168        use_public_proxy: bool,
169    ) -> Result<Option<T>, ApiError> {
170        let status = resp.status();
171
172        match status {
173            StatusCode::OK => {
174                let body = resp
175                    .json::<T>()
176                    .await
177                    .map_err(|e| ApiError::Parse(format!("Failed to parse response: {}", e)))?;
178                Ok(Some(body))
179            }
180            StatusCode::NOT_FOUND => Ok(None),
181            StatusCode::UNAUTHORIZED => {
182                Err(ApiError::Unauthorized("Unauthorized: Invalid API token".into()))
183            }
184            StatusCode::FORBIDDEN => {
185                let msg = if use_public_proxy {
186                    "Forbidden: This patch is only available to paid subscribers. \
187                     Sign up at https://socket.dev to access paid patches."
188                } else {
189                    "Forbidden: Access denied. This may be a paid patch or \
190                     you may not have access to this organization."
191                };
192                Err(ApiError::Forbidden(msg.into()))
193            }
194            StatusCode::TOO_MANY_REQUESTS => {
195                Err(ApiError::RateLimited(
196                    "Rate limit exceeded. Please try again later.".into(),
197                ))
198            }
199            _ => {
200                let text = resp.text().await.unwrap_or_default();
201                Err(ApiError::Other(format!(
202                    "API request failed with status {}: {}",
203                    status.as_u16(),
204                    text
205                )))
206            }
207        }
208    }
209
210    // ── Public API methods ────────────────────────────────────────────
211
212    /// Fetch a patch by UUID (full details with blob content).
213    ///
214    /// Returns `Ok(None)` when the patch is not found (404).
215    pub async fn fetch_patch(
216        &self,
217        org_slug: Option<&str>,
218        uuid: &str,
219    ) -> Result<Option<PatchResponse>, ApiError> {
220        let path = if self.use_public_proxy {
221            format!("/patch/view/{}", uuid)
222        } else {
223            let slug = org_slug
224                .or(self.org_slug.as_deref())
225                .unwrap_or("default");
226            format!("/v0/orgs/{}/patches/view/{}", slug, uuid)
227        };
228        self.get_json(&path).await
229    }
230
231    /// Search patches by CVE ID.
232    pub async fn search_patches_by_cve(
233        &self,
234        org_slug: Option<&str>,
235        cve_id: &str,
236    ) -> Result<SearchResponse, ApiError> {
237        let encoded = urlencoding_encode(cve_id);
238        let path = if self.use_public_proxy {
239            format!("/patch/by-cve/{}", encoded)
240        } else {
241            let slug = org_slug
242                .or(self.org_slug.as_deref())
243                .unwrap_or("default");
244            format!("/v0/orgs/{}/patches/by-cve/{}", slug, encoded)
245        };
246        let result = self.get_json::<SearchResponse>(&path).await?;
247        Ok(result.unwrap_or_else(|| SearchResponse {
248            patches: Vec::new(),
249            can_access_paid_patches: false,
250        }))
251    }
252
253    /// Search patches by GHSA ID.
254    pub async fn search_patches_by_ghsa(
255        &self,
256        org_slug: Option<&str>,
257        ghsa_id: &str,
258    ) -> Result<SearchResponse, ApiError> {
259        let encoded = urlencoding_encode(ghsa_id);
260        let path = if self.use_public_proxy {
261            format!("/patch/by-ghsa/{}", encoded)
262        } else {
263            let slug = org_slug
264                .or(self.org_slug.as_deref())
265                .unwrap_or("default");
266            format!("/v0/orgs/{}/patches/by-ghsa/{}", slug, encoded)
267        };
268        let result = self.get_json::<SearchResponse>(&path).await?;
269        Ok(result.unwrap_or_else(|| SearchResponse {
270            patches: Vec::new(),
271            can_access_paid_patches: false,
272        }))
273    }
274
275    /// Search patches by package PURL.
276    ///
277    /// The PURL must be a valid Package URL starting with `pkg:`.
278    /// Examples: `pkg:npm/lodash@4.17.21`, `pkg:pypi/django@3.2.0`
279    pub async fn search_patches_by_package(
280        &self,
281        org_slug: Option<&str>,
282        purl: &str,
283    ) -> Result<SearchResponse, ApiError> {
284        let encoded = urlencoding_encode(purl);
285        let path = if self.use_public_proxy {
286            format!("/patch/by-package/{}", encoded)
287        } else {
288            let slug = org_slug
289                .or(self.org_slug.as_deref())
290                .unwrap_or("default");
291            format!("/v0/orgs/{}/patches/by-package/{}", slug, encoded)
292        };
293        let result = self.get_json::<SearchResponse>(&path).await?;
294        Ok(result.unwrap_or_else(|| SearchResponse {
295            patches: Vec::new(),
296            can_access_paid_patches: false,
297        }))
298    }
299
300    /// Search patches for multiple packages (batch).
301    ///
302    /// For authenticated API, uses the POST `/patches/batch` endpoint.
303    /// For the public proxy (which cannot cache POST bodies on CDN), falls
304    /// back to individual GET requests per PURL with a concurrency limit of
305    /// 10.
306    ///
307    /// Maximum 500 PURLs per request.
308    pub async fn search_patches_batch(
309        &self,
310        org_slug: Option<&str>,
311        purls: &[String],
312    ) -> Result<BatchSearchResponse, ApiError> {
313        if !self.use_public_proxy {
314            let slug = org_slug
315                .or(self.org_slug.as_deref())
316                .unwrap_or("default");
317            let path = format!("/v0/orgs/{}/patches/batch", slug);
318            let body = BatchSearchBody {
319                components: purls
320                    .iter()
321                    .map(|p| BatchComponent { purl: p.clone() })
322                    .collect(),
323            };
324            let result = self.post_json::<BatchSearchResponse, _>(&path, &body).await?;
325            return Ok(result.unwrap_or_else(|| BatchSearchResponse {
326                packages: Vec::new(),
327                can_access_paid_patches: false,
328            }));
329        }
330
331        // Public proxy: fall back to individual per-package GET requests
332        self.search_patches_batch_via_individual_queries(purls).await
333    }
334
335    /// Internal: fall back to individual GET requests per PURL when the
336    /// batch endpoint is not available (public proxy mode).
337    ///
338    /// Processes PURLs in batches of `CONCURRENCY_LIMIT` to avoid
339    /// overwhelming the server while remaining efficient.
340    async fn search_patches_batch_via_individual_queries(
341        &self,
342        purls: &[String],
343    ) -> Result<BatchSearchResponse, ApiError> {
344        const CONCURRENCY_LIMIT: usize = 10;
345
346        let mut packages: Vec<BatchPackagePatches> = Vec::new();
347        let mut can_access_paid_patches = false;
348
349        // Collect all (purl, response) pairs
350        let mut all_results: Vec<(String, Option<SearchResponse>)> = Vec::new();
351
352        for chunk in purls.chunks(CONCURRENCY_LIMIT) {
353            // Use tokio::JoinSet for concurrent execution within each chunk
354            let mut join_set = tokio::task::JoinSet::new();
355
356            for purl in chunk {
357                let purl = purl.clone();
358                let client = self.clone();
359                join_set.spawn(async move {
360                    let resp = client.search_patches_by_package(None, &purl).await;
361                    match resp {
362                        Ok(r) => (purl, Some(r)),
363                        Err(e) => {
364                            debug_log(&format!("Error fetching patches for {}: {}", purl, e));
365                            (purl, None)
366                        }
367                    }
368                });
369            }
370
371            while let Some(result) = join_set.join_next().await {
372                match result {
373                    Ok(pair) => all_results.push(pair),
374                    Err(e) => {
375                        debug_log(&format!("Task join error: {}", e));
376                    }
377                }
378            }
379        }
380
381        // Convert individual SearchResponse results to BatchSearchResponse format
382        for (purl, response) in all_results {
383            let response = match response {
384                Some(r) if !r.patches.is_empty() => r,
385                _ => continue,
386            };
387
388            if response.can_access_paid_patches {
389                can_access_paid_patches = true;
390            }
391
392            let batch_patches: Vec<BatchPatchInfo> = response
393                .patches
394                .into_iter()
395                .map(convert_search_result_to_batch_info)
396                .collect();
397
398            packages.push(BatchPackagePatches {
399                purl,
400                patches: batch_patches,
401            });
402        }
403
404        Ok(BatchSearchResponse {
405            packages,
406            can_access_paid_patches,
407        })
408    }
409
410    /// Fetch organizations accessible to the current API token.
411    pub async fn fetch_organizations(
412        &self,
413    ) -> Result<Vec<crate::api::types::OrganizationInfo>, ApiError> {
414        let path = "/v0/organizations";
415        match self
416            .get_json::<crate::api::types::OrganizationsResponse>(path)
417            .await?
418        {
419            Some(resp) => Ok(resp.organizations.into_values().collect()),
420            None => Ok(Vec::new()),
421        }
422    }
423
424    /// Resolve the org slug from the API token by querying `/v0/organizations`.
425    ///
426    /// If there is exactly one org, returns its slug.
427    /// If there are multiple, picks the first and prints a warning.
428    /// If there are none, returns an error.
429    pub async fn resolve_org_slug(&self) -> Result<String, ApiError> {
430        let orgs = self.fetch_organizations().await?;
431        match orgs.len() {
432            0 => Err(ApiError::Other(
433                "No organizations found for this API token.".into(),
434            )),
435            1 => Ok(orgs.into_iter().next().unwrap().slug),
436            _ => {
437                let slugs: Vec<_> = orgs.iter().map(|o| o.slug.as_str()).collect();
438                let first = orgs[0].slug.clone();
439                eprintln!(
440                    "Multiple organizations found: {}. Using \"{}\". \
441                     Pass --org to select a different one.",
442                    slugs.join(", "),
443                    first
444                );
445                Ok(first)
446            }
447        }
448    }
449
450    /// Fetch a blob by its SHA-256 hash.
451    ///
452    /// Returns the raw binary content, or `Ok(None)` if not found.
453    /// Uses the authenticated endpoint when token and org slug are
454    /// available, otherwise falls back to the public proxy.
455    pub async fn fetch_blob(&self, hash: &str) -> Result<Option<Vec<u8>>, ApiError> {
456        // Validate hash format: SHA-256 = 64 hex characters
457        if !is_valid_sha256_hex(hash) {
458            return Err(ApiError::InvalidHash(format!(
459                "Invalid hash format: {}. Expected SHA256 hash (64 hex characters).",
460                hash
461            )));
462        }
463
464        let (url, use_auth) =
465            if self.api_token.is_some() && self.org_slug.is_some() && !self.use_public_proxy {
466                // Authenticated endpoint
467                let slug = self.org_slug.as_deref().unwrap();
468                let u = format!("{}/v0/orgs/{}/patches/blob/{}", self.api_url, slug, hash);
469                (u, true)
470            } else {
471                // Public proxy
472                let proxy_url = std::env::var("SOCKET_PATCH_PROXY_URL")
473                    .unwrap_or_else(|_| DEFAULT_PATCH_API_PROXY_URL.to_string());
474                let u = format!("{}/patch/blob/{}", proxy_url.trim_end_matches('/'), hash);
475                (u, false)
476            };
477
478        debug_log(&format!("GET blob {}", url));
479
480        // Build the request. When fetching from the public proxy (different
481        // base URL than self.api_url), we use a plain client without auth
482        // headers to avoid leaking credentials to the proxy.
483        let resp = if use_auth {
484            self.client
485                .get(&url)
486                .header(header::ACCEPT, "application/octet-stream")
487                .send()
488                .await
489        } else {
490            let mut headers = HeaderMap::new();
491            headers.insert(
492                header::USER_AGENT,
493                HeaderValue::from_static(USER_AGENT_VALUE),
494            );
495            headers.insert(
496                header::ACCEPT,
497                HeaderValue::from_static("application/octet-stream"),
498            );
499
500            let plain_client = reqwest::Client::builder()
501                .default_headers(headers)
502                .build()
503                .expect("failed to build plain reqwest client");
504
505            plain_client.get(&url).send().await
506        };
507
508        let resp = resp.map_err(|e| {
509            ApiError::Network(format!("Network error fetching blob {}: {}", hash, e))
510        })?;
511
512        let status = resp.status();
513
514        match status {
515            StatusCode::OK => {
516                let bytes = resp.bytes().await.map_err(|e| {
517                    ApiError::Network(format!("Error reading blob body for {}: {}", hash, e))
518                })?;
519                Ok(Some(bytes.to_vec()))
520            }
521            StatusCode::NOT_FOUND => Ok(None),
522            _ => {
523                let text = resp.text().await.unwrap_or_default();
524                Err(ApiError::Other(format!(
525                    "Failed to fetch blob {}: status {} - {}",
526                    hash,
527                    status.as_u16(),
528                    text,
529                )))
530            }
531        }
532    }
533}
534
535// ── Free functions ────────────────────────────────────────────────────
536
537/// Get an API client configured from environment variables.
538///
539/// If `SOCKET_API_TOKEN` is not set, the client will use the public patch
540/// API proxy which provides free access to free-tier patches without
541/// authentication.
542///
543/// When `SOCKET_API_TOKEN` is set but no org slug is provided (neither via
544/// argument nor `SOCKET_ORG_SLUG` env var), the function will attempt to
545/// auto-resolve the org slug by querying `GET /v0/organizations`.
546///
547/// # Environment variables
548///
549/// | Variable | Purpose |
550/// |---|---|
551/// | `SOCKET_API_URL` | Override the API URL (default `https://api.socket.dev`) |
552/// | `SOCKET_API_TOKEN` | API token for authenticated access |
553/// | `SOCKET_PATCH_PROXY_URL` | Override the public proxy URL (default `https://patches-api.socket.dev`) |
554/// | `SOCKET_ORG_SLUG` | Organization slug |
555///
556/// Returns `(client, use_public_proxy)`.
557pub async fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) {
558    let api_token = std::env::var("SOCKET_API_TOKEN")
559        .ok()
560        .filter(|t| !t.is_empty());
561    let resolved_org_slug = org_slug
562        .map(String::from)
563        .or_else(|| std::env::var("SOCKET_ORG_SLUG").ok());
564
565    if api_token.is_none() {
566        let proxy_url = std::env::var("SOCKET_PATCH_PROXY_URL")
567            .unwrap_or_else(|_| DEFAULT_PATCH_API_PROXY_URL.to_string());
568        eprintln!(
569            "No SOCKET_API_TOKEN set. Using public patch API proxy (free patches only)."
570        );
571        let client = ApiClient::new(ApiClientOptions {
572            api_url: proxy_url,
573            api_token: None,
574            use_public_proxy: true,
575            org_slug: None,
576        });
577        return (client, true);
578    }
579
580    let api_url =
581        std::env::var("SOCKET_API_URL").unwrap_or_else(|_| DEFAULT_SOCKET_API_URL.to_string());
582
583    // Auto-resolve org slug if not provided
584    let final_org_slug = if resolved_org_slug.is_some() {
585        resolved_org_slug
586    } else {
587        let temp_client = ApiClient::new(ApiClientOptions {
588            api_url: api_url.clone(),
589            api_token: api_token.clone(),
590            use_public_proxy: false,
591            org_slug: None,
592        });
593        match temp_client.resolve_org_slug().await {
594            Ok(slug) => Some(slug),
595            Err(e) => {
596                eprintln!("Warning: Could not auto-detect organization: {e}");
597                None
598            }
599        }
600    };
601
602    let client = ApiClient::new(ApiClientOptions {
603        api_url,
604        api_token,
605        use_public_proxy: false,
606        org_slug: final_org_slug,
607    });
608    (client, false)
609}
610
611// ── Helpers ───────────────────────────────────────────────────────────
612
/// Percent-encode `input` for use as a single URL path segment.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` (the RFC 3986 unreserved
/// set) are copied through; every other byte of the UTF-8 encoding is
/// written as `%` followed by two uppercase hex digits.
fn urlencoding_encode(input: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(input.len());
    for &b in input.as_bytes() {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(b));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(b & 0x0F)]));
        }
    }
    encoded
}
630
/// Shorten `s` to at most `max_chars` characters, appending "..." when cut.
///
/// Strings that already fit are returned unchanged. The cut always falls
/// on a character boundary, so multi-byte UTF-8 text is handled safely.
fn truncate_to_chars(s: &str, max_chars: usize) -> String {
    // A character at index `max_chars` exists exactly when `s` is longer
    // than the limit; its byte offset is where the kept prefix ends.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => {
            let mut shortened = String::with_capacity(cut + 3);
            shortened.push_str(&s[..cut]);
            shortened.push_str("...");
            shortened
        }
        None => s.to_string(),
    }
}
640
/// Check that `s` has the textual shape of a SHA-256 digest: exactly 64
/// bytes, each an ASCII hex digit (either case).
fn is_valid_sha256_hex(s: &str) -> bool {
    if s.len() != 64 {
        return false;
    }
    s.bytes()
        .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F'))
}
645
646/// Convert a `PatchSearchResult` into a `BatchPatchInfo`, extracting
647/// CVE/GHSA IDs and computing the highest severity.
648fn convert_search_result_to_batch_info(patch: PatchSearchResult) -> BatchPatchInfo {
649    let mut cve_ids: Vec<String> = Vec::new();
650    let mut ghsa_ids: Vec<String> = Vec::new();
651    let mut highest_severity: Option<String> = None;
652    let mut title = String::new();
653
654    let mut seen_cves: HashSet<String> = HashSet::new();
655
656    for (ghsa_id, vuln) in &patch.vulnerabilities {
657        ghsa_ids.push(ghsa_id.clone());
658
659        for cve in &vuln.cves {
660            if seen_cves.insert(cve.clone()) {
661                cve_ids.push(cve.clone());
662            }
663        }
664
665        // Track highest severity (lower order number = higher severity)
666        let current_order = get_severity_order(highest_severity.as_deref());
667        let vuln_order = get_severity_order(Some(&vuln.severity));
668        if vuln_order < current_order {
669            highest_severity = Some(vuln.severity.clone());
670        }
671
672        // Use first non-empty summary as title
673        if title.is_empty() && !vuln.summary.is_empty() {
674            title = truncate_to_chars(&vuln.summary, 97);
675        }
676    }
677
678    // Use description as fallback title
679    if title.is_empty() && !patch.description.is_empty() {
680        title = truncate_to_chars(&patch.description, 97);
681    }
682
683    cve_ids.sort();
684    ghsa_ids.sort();
685
686    BatchPatchInfo {
687        uuid: patch.uuid,
688        purl: patch.purl,
689        tier: patch.tier,
690        cve_ids,
691        ghsa_ids,
692        severity: highest_severity,
693        title,
694    }
695}
696
697// ── Error type ────────────────────────────────────────────────────────
698
/// Errors returned by [`ApiClient`] methods.
///
/// Every variant wraps a ready-to-display message; `Display` prints that
/// message unchanged.
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
    /// The request could not be sent, or a response body could not be read.
    #[error("{0}")]
    Network(String),

    /// A 200 response whose body could not be decoded as the expected JSON.
    #[error("{0}")]
    Parse(String),

    /// The server answered 401.
    #[error("{0}")]
    Unauthorized(String),

    /// The server answered 403.
    #[error("{0}")]
    Forbidden(String),

    /// The server answered 429.
    #[error("{0}")]
    RateLimited(String),

    /// A blob hash was not a 64-character hex string.
    #[error("{0}")]
    InvalidHash(String),

    /// Any other failure, e.g. an unexpected HTTP status or a token with
    /// no organizations.
    #[error("{0}")]
    Other(String),
}
723
724#[cfg(test)]
725mod tests {
726    use super::*;
727    use std::collections::HashMap;
728
729    #[test]
730    fn test_urlencoding_basic() {
731        assert_eq!(urlencoding_encode("hello"), "hello");
732        assert_eq!(urlencoding_encode("a b"), "a%20b");
733        assert_eq!(
734            urlencoding_encode("pkg:npm/lodash@4.17.21"),
735            "pkg%3Anpm%2Flodash%404.17.21"
736        );
737    }
738
739    #[test]
740    fn test_is_valid_sha256_hex() {
741        let valid = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
742        assert!(is_valid_sha256_hex(valid));
743
744        // Too short
745        assert!(!is_valid_sha256_hex("abcdef"));
746        // Non-hex
747        assert!(!is_valid_sha256_hex(
748            "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"
749        ));
750    }
751
752    #[test]
753    fn test_severity_order() {
754        assert!(get_severity_order(Some("critical")) < get_severity_order(Some("high")));
755        assert!(get_severity_order(Some("high")) < get_severity_order(Some("medium")));
756        assert!(get_severity_order(Some("medium")) < get_severity_order(Some("low")));
757        assert!(get_severity_order(Some("low")) < get_severity_order(None));
758        assert_eq!(get_severity_order(Some("unknown")), get_severity_order(None));
759    }
760
761    #[test]
762    fn test_convert_search_result_to_batch_info() {
763        let mut vulns = HashMap::new();
764        vulns.insert(
765            "GHSA-1234-5678-9abc".to_string(),
766            VulnerabilityResponse {
767                cves: vec!["CVE-2024-0001".into()],
768                summary: "Test vulnerability".into(),
769                severity: "high".into(),
770                description: "A test vuln".into(),
771            },
772        );
773
774        let patch = PatchSearchResult {
775            uuid: "uuid-1".into(),
776            purl: "pkg:npm/test@1.0.0".into(),
777            published_at: "2024-01-01".into(),
778            description: "A patch".into(),
779            license: "MIT".into(),
780            tier: "free".into(),
781            vulnerabilities: vulns,
782        };
783
784        let info = convert_search_result_to_batch_info(patch);
785        assert_eq!(info.uuid, "uuid-1");
786        assert_eq!(info.cve_ids, vec!["CVE-2024-0001"]);
787        assert_eq!(info.ghsa_ids, vec!["GHSA-1234-5678-9abc"]);
788        assert_eq!(info.severity, Some("high".into()));
789        assert_eq!(info.title, "Test vulnerability");
790    }
791
792    #[tokio::test]
793    async fn test_get_api_client_from_env_no_token() {
794        // Clear token to ensure public proxy mode
795        std::env::remove_var("SOCKET_API_TOKEN");
796        let (client, is_public) = get_api_client_from_env(None).await;
797        assert!(is_public);
798        assert!(client.use_public_proxy);
799    }
800
801    // ── Group 6: convert_search_result_to_batch_info edge cases ──────
802
803    fn make_vuln(summary: &str, severity: &str, cves: Vec<&str>) -> VulnerabilityResponse {
804        VulnerabilityResponse {
805            cves: cves.into_iter().map(String::from).collect(),
806            summary: summary.into(),
807            severity: severity.into(),
808            description: "desc".into(),
809        }
810    }
811
812    fn make_patch(
813        vulns: HashMap<String, VulnerabilityResponse>,
814        description: &str,
815    ) -> PatchSearchResult {
816        PatchSearchResult {
817            uuid: "uuid-1".into(),
818            purl: "pkg:npm/test@1.0.0".into(),
819            published_at: "2024-01-01".into(),
820            description: description.into(),
821            license: "MIT".into(),
822            tier: "free".into(),
823            vulnerabilities: vulns,
824        }
825    }
826
827    #[test]
828    fn test_convert_no_vulnerabilities() {
829        let patch = make_patch(HashMap::new(), "A patch description");
830        let info = convert_search_result_to_batch_info(patch);
831        assert!(info.cve_ids.is_empty());
832        assert!(info.ghsa_ids.is_empty());
833        assert_eq!(info.title, "A patch description");
834        assert!(info.severity.is_none());
835    }
836
837    #[test]
838    fn test_convert_multiple_vulns_picks_highest_severity() {
839        let mut vulns = HashMap::new();
840        vulns.insert(
841            "GHSA-1111".into(),
842            make_vuln("Medium vuln", "medium", vec!["CVE-2024-0001"]),
843        );
844        vulns.insert(
845            "GHSA-2222".into(),
846            make_vuln("Critical vuln", "critical", vec!["CVE-2024-0002"]),
847        );
848        let patch = make_patch(vulns, "desc");
849        let info = convert_search_result_to_batch_info(patch);
850        assert_eq!(info.severity, Some("critical".into()));
851    }
852
853    #[test]
854    fn test_convert_duplicate_cves_deduplicated() {
855        let mut vulns = HashMap::new();
856        vulns.insert(
857            "GHSA-1111".into(),
858            make_vuln("Vuln A", "high", vec!["CVE-2024-0001"]),
859        );
860        vulns.insert(
861            "GHSA-2222".into(),
862            make_vuln("Vuln B", "high", vec!["CVE-2024-0001"]),
863        );
864        let patch = make_patch(vulns, "desc");
865        let info = convert_search_result_to_batch_info(patch);
866        // Same CVE in both vulns should only appear once
867        let cve_count = info.cve_ids.iter().filter(|c| *c == "CVE-2024-0001").count();
868        assert_eq!(cve_count, 1);
869    }
870
871    #[test]
872    fn test_convert_title_truncated_at_100() {
873        let long_summary = "x".repeat(150);
874        let mut vulns = HashMap::new();
875        vulns.insert(
876            "GHSA-1111".into(),
877            make_vuln(&long_summary, "high", vec![]),
878        );
879        let patch = make_patch(vulns, "desc");
880        let info = convert_search_result_to_batch_info(patch);
881        // Should be 97 chars + "..." = 100 chars
882        assert_eq!(info.title.len(), 100);
883        assert!(info.title.ends_with("..."));
884    }
885
886    #[test]
887    fn test_convert_title_unicode_truncation() {
888        // Create a summary with multi-byte chars that would panic with byte slicing
889        // Each emoji is 4 bytes, so 30 emojis = 120 bytes but only 30 chars
890        let emoji_summary = "\u{1F600}".repeat(30);
891        let mut vulns = HashMap::new();
892        vulns.insert(
893            "GHSA-1111".into(),
894            make_vuln(&emoji_summary, "high", vec![]),
895        );
896        let patch = make_patch(vulns, "desc");
897        // This should NOT panic (validates the UTF-8 truncation fix)
898        let info = convert_search_result_to_batch_info(patch);
899        assert!(!info.title.is_empty());
900
901        // Also test with description fallback
902        let patch2 = make_patch(HashMap::new(), &"\u{1F600}".repeat(120));
903        let info2 = convert_search_result_to_batch_info(patch2);
904        assert!(info2.title.ends_with("..."));
905    }
906
907    #[test]
908    fn test_convert_title_falls_back_to_description() {
909        let mut vulns = HashMap::new();
910        vulns.insert(
911            "GHSA-1111".into(),
912            make_vuln("", "high", vec![]),
913        );
914        let patch = make_patch(vulns, "Fallback desc");
915        let info = convert_search_result_to_batch_info(patch);
916        assert_eq!(info.title, "Fallback desc");
917    }
918
919    #[test]
920    fn test_convert_empty_summary_and_description() {
921        let mut vulns = HashMap::new();
922        vulns.insert(
923            "GHSA-1111".into(),
924            make_vuln("", "high", vec![]),
925        );
926        let patch = make_patch(vulns, "");
927        let info = convert_search_result_to_batch_info(patch);
928        assert!(info.title.is_empty());
929    }
930
931    #[test]
932    fn test_convert_cves_and_ghsas_sorted() {
933        let mut vulns = HashMap::new();
934        vulns.insert(
935            "GHSA-cccc".into(),
936            make_vuln("V1", "high", vec!["CVE-2024-0003"]),
937        );
938        vulns.insert(
939            "GHSA-aaaa".into(),
940            make_vuln("V2", "high", vec!["CVE-2024-0001"]),
941        );
942        vulns.insert(
943            "GHSA-bbbb".into(),
944            make_vuln("V3", "high", vec!["CVE-2024-0002"]),
945        );
946        let patch = make_patch(vulns, "desc");
947        let info = convert_search_result_to_batch_info(patch);
948        // Both should be sorted alphabetically
949        let mut sorted_cves = info.cve_ids.clone();
950        sorted_cves.sort();
951        assert_eq!(info.cve_ids, sorted_cves);
952        let mut sorted_ghsas = info.ghsa_ids.clone();
953        sorted_ghsas.sort();
954        assert_eq!(info.ghsa_ids, sorted_ghsas);
955    }
956
957    // ── Group 7: urlencoding + SHA256 edge cases ─────────────────────
958
959    #[test]
960    fn test_urlencoding_unicode() {
961        // Multi-byte UTF-8: 'é' = 0xC3 0xA9
962        let encoded = urlencoding_encode("café");
963        assert_eq!(encoded, "caf%C3%A9");
964    }
965
966    #[test]
967    fn test_urlencoding_empty() {
968        assert_eq!(urlencoding_encode(""), "");
969    }
970
971    #[test]
972    fn test_urlencoding_all_safe_chars() {
973        // Unreserved chars should pass through
974        let safe = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~";
975        assert_eq!(urlencoding_encode(safe), safe);
976    }
977
978    #[test]
979    fn test_urlencoding_slash_and_at() {
980        assert_eq!(urlencoding_encode("/"), "%2F");
981        assert_eq!(urlencoding_encode("@"), "%40");
982    }
983
984    #[test]
985    fn test_sha256_uppercase_valid() {
986        let upper = "ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789";
987        assert!(is_valid_sha256_hex(upper));
988    }
989
990    #[test]
991    fn test_sha256_65_chars_invalid() {
992        let too_long = "a".repeat(65);
993        assert!(!is_valid_sha256_hex(&too_long));
994    }
995
996    #[test]
997    fn test_sha256_63_chars_invalid() {
998        let too_short = "a".repeat(63);
999        assert!(!is_valid_sha256_hex(&too_short));
1000    }
1001
1002    #[test]
1003    fn test_sha256_empty_invalid() {
1004        assert!(!is_valid_sha256_hex(""));
1005    }
1006
1007    #[test]
1008    fn test_sha256_mixed_case_valid() {
1009        let mixed = "aAbBcCdDeEfF0123456789aAbBcCdDeEfF0123456789aAbBcCdDeEfF01234567";
1010        assert_eq!(mixed.len(), 64);
1011        assert!(is_valid_sha256_hex(mixed));
1012    }
1013}