Skip to main content

socket_patch_core/api/
client.rs

1use std::collections::HashSet;
2
3use reqwest::header::{self, HeaderMap, HeaderValue};
4use reqwest::StatusCode;
5use serde::Serialize;
6
7use crate::api::types::*;
8use crate::constants::{
9    DEFAULT_PATCH_API_PROXY_URL, DEFAULT_SOCKET_API_URL, USER_AGENT as USER_AGENT_VALUE,
10};
11use crate::utils::env_compat::read_env_with_legacy;
12
13/// Check if debug mode is enabled via SOCKET_DEBUG env (falling back to the
14/// legacy SOCKET_PATCH_DEBUG name with a one-shot deprecation warning).
15fn is_debug_enabled() -> bool {
16    match read_env_with_legacy("SOCKET_DEBUG", "SOCKET_PATCH_DEBUG") {
17        Some(val) => val == "1" || val == "true",
18        None => false,
19    }
20}
21
22/// Log debug messages when debug mode is enabled.
23fn debug_log(message: &str) {
24    if is_debug_enabled() {
25        eprintln!("[socket-patch debug] {}", message);
26    }
27}
28
/// Rank a severity label for sorting: lower numbers are more severe.
///
/// Matching is case-insensitive: `critical` → 0, `high` → 1, `medium` → 2,
/// `low` → 3. Any other label, or `None`, ranks last at 4.
fn get_severity_order(severity: Option<&str>) -> u8 {
    const RANKED: [&str; 4] = ["critical", "high", "medium", "low"];
    const UNRANKED: u8 = 4;

    match severity {
        None => UNRANKED,
        Some(raw) => {
            let normalized = raw.to_lowercase();
            RANKED
                .iter()
                .position(|level| normalized.as_str() == *level)
                .map_or(UNRANKED, |rank| rank as u8)
        }
    }
}
39
/// Options for constructing an [`ApiClient`].
///
/// `Debug` is implemented by hand so that the API token is never printed
/// (e.g. by `{:?}` in logs or panic messages); only its presence is shown.
#[derive(Clone)]
pub struct ApiClientOptions {
    /// Base URL that request paths are appended to.
    pub api_url: String,
    /// API token; when set it is sent as `Authorization: Bearer <token>`.
    pub api_token: Option<String>,
    /// When true, the client will use the public patch API proxy
    /// which only provides access to free patches without authentication.
    pub use_public_proxy: bool,
    /// Organization slug for authenticated API access.
    /// Required when using authenticated API (not public proxy).
    pub org_slug: Option<String>,
}

impl std::fmt::Debug for ApiClientOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiClientOptions")
            .field("api_url", &self.api_url)
            // Redact the secret but keep Some/None visible for diagnostics.
            .field("api_token", &self.api_token.as_ref().map(|_| "<redacted>"))
            .field("use_public_proxy", &self.use_public_proxy)
            .field("org_slug", &self.org_slug)
            .finish()
    }
}
52
53/// HTTP client for the Socket Patch API.
54///
55/// Supports both the authenticated Socket API (`api.socket.dev`) and the
56/// public proxy (`patches-api.socket.dev`) which serves free patches
57/// without authentication.
58#[derive(Debug, Clone)]
59pub struct ApiClient {
60    client: reqwest::Client,
61    api_url: String,
62    api_token: Option<String>,
63    use_public_proxy: bool,
64    org_slug: Option<String>,
65}
66
67/// Body payload for the batch search POST endpoint.
68#[derive(Serialize)]
69struct BatchSearchBody {
70    components: Vec<BatchComponent>,
71}
72
73#[derive(Serialize)]
74struct BatchComponent {
75    purl: String,
76}
77
78impl ApiClient {
79    /// Create a new API client from the given options.
80    ///
81    /// Constructs a `reqwest::Client` with proper default headers
82    /// (User-Agent, Accept, and optionally Authorization).
83    pub fn new(options: ApiClientOptions) -> Self {
84        let api_url = options.api_url.trim_end_matches('/').to_string();
85
86        let mut default_headers = HeaderMap::new();
87        default_headers.insert(
88            header::USER_AGENT,
89            HeaderValue::from_static(USER_AGENT_VALUE),
90        );
91        default_headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
92
93        if let Some(ref token) = options.api_token {
94            if let Ok(hv) = HeaderValue::from_str(&format!("Bearer {}", token)) {
95                default_headers.insert(header::AUTHORIZATION, hv);
96            }
97        }
98
99        let client = reqwest::Client::builder()
100            .default_headers(default_headers)
101            .build()
102            .expect("failed to build reqwest client");
103
104        Self {
105            client,
106            api_url,
107            api_token: options.api_token,
108            use_public_proxy: options.use_public_proxy,
109            org_slug: options.org_slug,
110        }
111    }
112
113    /// Returns the API token, if set.
114    pub fn api_token(&self) -> Option<&String> {
115        self.api_token.as_ref()
116    }
117
118    /// Returns the org slug, if set.
119    pub fn org_slug(&self) -> Option<&String> {
120        self.org_slug.as_ref()
121    }
122
123    // ── Internal helpers ──────────────────────────────────────────────
124
125    /// Internal GET that deserialises JSON. Returns `Ok(None)` on 404.
126    async fn get_json<T: serde::de::DeserializeOwned>(
127        &self,
128        path: &str,
129    ) -> Result<Option<T>, ApiError> {
130        let url = format!("{}{}", self.api_url, path);
131        debug_log(&format!("GET {}", url));
132
133        let resp = self
134            .client
135            .get(&url)
136            .send()
137            .await
138            .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
139
140        Self::handle_json_response(resp, self.use_public_proxy).await
141    }
142
143    /// Internal POST that deserialises JSON. Returns `Ok(None)` on 404.
144    async fn post_json<T: serde::de::DeserializeOwned, B: Serialize>(
145        &self,
146        path: &str,
147        body: &B,
148    ) -> Result<Option<T>, ApiError> {
149        let url = format!("{}{}", self.api_url, path);
150        debug_log(&format!("POST {}", url));
151
152        let resp = self
153            .client
154            .post(&url)
155            .header(header::CONTENT_TYPE, "application/json")
156            .json(body)
157            .send()
158            .await
159            .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
160
161        Self::handle_json_response(resp, self.use_public_proxy).await
162    }
163
164    /// Map an HTTP response to `Ok(Some(T))`, `Ok(None)` (404), or `Err`.
165    async fn handle_json_response<T: serde::de::DeserializeOwned>(
166        resp: reqwest::Response,
167        use_public_proxy: bool,
168    ) -> Result<Option<T>, ApiError> {
169        let status = resp.status();
170
171        match status {
172            StatusCode::OK => {
173                let body = resp
174                    .json::<T>()
175                    .await
176                    .map_err(|e| ApiError::Parse(format!("Failed to parse response: {}", e)))?;
177                Ok(Some(body))
178            }
179            StatusCode::NOT_FOUND => Ok(None),
180            StatusCode::UNAUTHORIZED => Err(ApiError::Unauthorized(
181                "Unauthorized: Invalid API token".into(),
182            )),
183            StatusCode::FORBIDDEN => {
184                let msg = if use_public_proxy {
185                    "Forbidden: This patch is only available to paid subscribers. \
186                     Sign up at https://socket.dev to access paid patches."
187                } else {
188                    "Forbidden: Access denied. This may be a paid patch or \
189                     you may not have access to this organization."
190                };
191                Err(ApiError::Forbidden(msg.into()))
192            }
193            StatusCode::TOO_MANY_REQUESTS => Err(ApiError::RateLimited(
194                "Rate limit exceeded. Please try again later.".into(),
195            )),
196            _ => {
197                let text = resp.text().await.unwrap_or_default();
198                Err(ApiError::Other(format!(
199                    "API request failed with status {}: {}",
200                    status.as_u16(),
201                    text
202                )))
203            }
204        }
205    }
206
207    // ── Public API methods ────────────────────────────────────────────
208
209    /// Fetch a patch by UUID (full details with blob content).
210    ///
211    /// Returns `Ok(None)` when the patch is not found (404).
212    pub async fn fetch_patch(
213        &self,
214        org_slug: Option<&str>,
215        uuid: &str,
216    ) -> Result<Option<PatchResponse>, ApiError> {
217        let path = if self.use_public_proxy {
218            format!("/patch/view/{}", uuid)
219        } else {
220            let slug = org_slug.or(self.org_slug.as_deref()).unwrap_or("default");
221            format!("/v0/orgs/{}/patches/view/{}", slug, uuid)
222        };
223        self.get_json(&path).await
224    }
225
226    /// Shared implementation for `search_patches_by_{cve,ghsa,package}`.
227    /// `route` is the `by-<x>` URL segment — the rest of the path layout
228    /// is identical across the three endpoints.
229    async fn search_patches_by_route(
230        &self,
231        org_slug: Option<&str>,
232        route: &str,
233        identifier: &str,
234    ) -> Result<SearchResponse, ApiError> {
235        let encoded = urlencoding_encode(identifier);
236        let path = if self.use_public_proxy {
237            format!("/patch/{route}/{encoded}")
238        } else {
239            let slug = org_slug.or(self.org_slug.as_deref()).unwrap_or("default");
240            format!("/v0/orgs/{slug}/patches/{route}/{encoded}")
241        };
242        let result = self.get_json::<SearchResponse>(&path).await?;
243        Ok(result.unwrap_or_else(|| SearchResponse {
244            patches: Vec::new(),
245            can_access_paid_patches: false,
246        }))
247    }
248
249    /// Search patches by CVE ID.
250    pub async fn search_patches_by_cve(
251        &self,
252        org_slug: Option<&str>,
253        cve_id: &str,
254    ) -> Result<SearchResponse, ApiError> {
255        self.search_patches_by_route(org_slug, "by-cve", cve_id)
256            .await
257    }
258
259    /// Search patches by GHSA ID.
260    pub async fn search_patches_by_ghsa(
261        &self,
262        org_slug: Option<&str>,
263        ghsa_id: &str,
264    ) -> Result<SearchResponse, ApiError> {
265        self.search_patches_by_route(org_slug, "by-ghsa", ghsa_id)
266            .await
267    }
268
269    /// Search patches by package PURL.
270    ///
271    /// The PURL must be a valid Package URL starting with `pkg:`.
272    /// Examples: `pkg:npm/lodash@4.17.21`, `pkg:pypi/django@3.2.0`
273    pub async fn search_patches_by_package(
274        &self,
275        org_slug: Option<&str>,
276        purl: &str,
277    ) -> Result<SearchResponse, ApiError> {
278        self.search_patches_by_route(org_slug, "by-package", purl)
279            .await
280    }
281
282    /// Search patches for multiple packages (batch).
283    ///
284    /// For authenticated API, uses the POST `/patches/batch` endpoint.
285    /// For the public proxy (which cannot cache POST bodies on CDN), falls
286    /// back to individual GET requests per PURL with a concurrency limit of
287    /// 10.
288    ///
289    /// Maximum 500 PURLs per request.
290    pub async fn search_patches_batch(
291        &self,
292        org_slug: Option<&str>,
293        purls: &[String],
294    ) -> Result<BatchSearchResponse, ApiError> {
295        if !self.use_public_proxy {
296            let slug = org_slug.or(self.org_slug.as_deref()).unwrap_or("default");
297            let path = format!("/v0/orgs/{}/patches/batch", slug);
298            let body = BatchSearchBody {
299                components: purls
300                    .iter()
301                    .map(|p| BatchComponent { purl: p.clone() })
302                    .collect(),
303            };
304            let result = self
305                .post_json::<BatchSearchResponse, _>(&path, &body)
306                .await?;
307            return Ok(result.unwrap_or_else(|| BatchSearchResponse {
308                packages: Vec::new(),
309                can_access_paid_patches: false,
310            }));
311        }
312
313        // Public proxy: fall back to individual per-package GET requests
314        self.search_patches_batch_via_individual_queries(purls)
315            .await
316    }
317
318    /// Internal: fall back to individual GET requests per PURL when the
319    /// batch endpoint is not available (public proxy mode).
320    ///
321    /// Processes PURLs in batches of `CONCURRENCY_LIMIT` to avoid
322    /// overwhelming the server while remaining efficient.
323    async fn search_patches_batch_via_individual_queries(
324        &self,
325        purls: &[String],
326    ) -> Result<BatchSearchResponse, ApiError> {
327        const CONCURRENCY_LIMIT: usize = 10;
328
329        // Collect all (purl, response) pairs
330        let mut all_results: Vec<(String, Option<SearchResponse>)> = Vec::new();
331
332        for chunk in purls.chunks(CONCURRENCY_LIMIT) {
333            // Use tokio::JoinSet for concurrent execution within each chunk
334            let mut join_set = tokio::task::JoinSet::new();
335
336            for purl in chunk {
337                let purl = purl.clone();
338                let client = self.clone();
339                join_set.spawn(async move {
340                    let resp = client.search_patches_by_package(None, &purl).await;
341                    match resp {
342                        Ok(r) => (purl, Some(r)),
343                        Err(e) => {
344                            debug_log(&format!("Error fetching patches for {}: {}", purl, e));
345                            (purl, None)
346                        }
347                    }
348                });
349            }
350
351            while let Some(result) = join_set.join_next().await {
352                match result {
353                    Ok(pair) => all_results.push(pair),
354                    Err(e) => {
355                        debug_log(&format!("Task join error: {}", e));
356                    }
357                }
358            }
359        }
360
361        // Convert the individual SearchResponse results into the batch shape.
362        Ok(assemble_batch_from_individual(all_results))
363    }
364
365    /// Fetch organizations accessible to the current API token.
366    pub async fn fetch_organizations(
367        &self,
368    ) -> Result<Vec<crate::api::types::OrganizationInfo>, ApiError> {
369        let path = "/v0/organizations";
370        match self
371            .get_json::<crate::api::types::OrganizationsResponse>(path)
372            .await?
373        {
374            Some(resp) => Ok(resp.organizations.into_values().collect()),
375            None => Ok(Vec::new()),
376        }
377    }
378
379    /// Resolve the org slug from the API token by querying `/v0/organizations`.
380    ///
381    /// If there is exactly one org, returns its slug.
382    /// If there are multiple, picks the first and prints a warning.
383    /// If there are none, returns an error.
384    pub async fn resolve_org_slug(&self) -> Result<String, ApiError> {
385        let orgs = self.fetch_organizations().await?;
386        select_org_slug(orgs)
387    }
388
389    /// Fetch a blob by its SHA-256 hash.
390    ///
391    /// Returns the raw binary content, or `Ok(None)` if not found.
392    /// Uses the authenticated endpoint when token and org slug are
393    /// available, otherwise falls back to the public proxy.
394    pub async fn fetch_blob(&self, hash: &str) -> Result<Option<Vec<u8>>, ApiError> {
395        // Validate hash format: SHA-256 = 64 hex characters
396        if !is_valid_sha256_hex(hash) {
397            return Err(ApiError::InvalidHash(format!(
398                "Invalid hash format: {}. Expected SHA256 hash (64 hex characters).",
399                hash
400            )));
401        }
402        self.fetch_binary("blob", "blob", hash).await
403    }
404
405    /// Fetch a per-file diff archive (tar.gz of bsdiff deltas) by patch UUID.
406    ///
407    /// Returns the raw archive bytes, or `Ok(None)` if not found (404). The
408    /// public proxy serves these under `/patch/diff/<uuid>`; the
409    /// authenticated API serves them under `/v0/orgs/<slug>/patches/diff/<uuid>`.
410    pub async fn fetch_diff(&self, uuid: &str) -> Result<Option<Vec<u8>>, ApiError> {
411        if !is_valid_uuid(uuid) {
412            return Err(ApiError::InvalidHash(format!(
413                "Invalid patch UUID: {}",
414                uuid
415            )));
416        }
417        self.fetch_binary("diff", "diff", uuid).await
418    }
419
420    /// Fetch a per-package patch archive (tar.gz of patched files) by patch UUID.
421    ///
422    /// Returns the raw archive bytes, or `Ok(None)` if not found (404).
423    pub async fn fetch_package(&self, uuid: &str) -> Result<Option<Vec<u8>>, ApiError> {
424        if !is_valid_uuid(uuid) {
425            return Err(ApiError::InvalidHash(format!(
426                "Invalid patch UUID: {}",
427                uuid
428            )));
429        }
430        self.fetch_binary("package", "package", uuid).await
431    }
432
433    /// Build the URL (and an `is_authenticated` flag) for a binary fetch of
434    /// `kind` (`blob` / `diff` / `package`) identified by `identifier`.
435    ///
436    /// Uses the authenticated `/v0/orgs/<slug>/patches/...` endpoint when a
437    /// token and org slug are configured (and we're not pinned to the public
438    /// proxy). Otherwise it targets the public proxy.
439    ///
440    /// In public-proxy mode the base is the client's own configured `api_url`
441    /// — the same value the JSON endpoints (`get_json`/`post_json`) use — so an
442    /// explicit `--proxy-url` / `SOCKET_PROXY_URL` override is honored for
443    /// binary downloads too. Only when falling back from an *authenticated*
444    /// client that lacks an org slug (so `api_url` is the auth host, not a
445    /// proxy) do we re-derive the proxy base from the environment.
446    fn binary_url(&self, kind: &str, identifier: &str) -> (String, bool) {
447        if self.api_token.is_some() && self.org_slug.is_some() && !self.use_public_proxy {
448            let slug = self.org_slug.as_deref().unwrap();
449            let u = format!(
450                "{}/v0/orgs/{}/patches/{}/{}",
451                self.api_url, slug, kind, identifier
452            );
453            (u, true)
454        } else {
455            let base = if self.use_public_proxy {
456                self.api_url.clone()
457            } else {
458                read_env_with_legacy("SOCKET_PROXY_URL", "SOCKET_PATCH_PROXY_URL")
459                    .unwrap_or_else(|| DEFAULT_PATCH_API_PROXY_URL.to_string())
460            };
461            let u = format!(
462                "{}/patch/{}/{}",
463                base.trim_end_matches('/'),
464                kind,
465                identifier
466            );
467            (u, false)
468        }
469    }
470
471    /// Shared implementation for `fetch_blob` / `fetch_diff` / `fetch_package`.
472    ///
473    /// `kind` is the URL segment (`blob` / `diff` / `package`). `label` is the
474    /// human-readable noun used in log + error messages. `identifier` is the
475    /// hash or UUID interpolated into the URL.
476    async fn fetch_binary(
477        &self,
478        kind: &str,
479        label: &str,
480        identifier: &str,
481    ) -> Result<Option<Vec<u8>>, ApiError> {
482        let (url, use_auth) = self.binary_url(kind, identifier);
483
484        debug_log(&format!("GET {} {}", label, url));
485
486        // Build the request. When fetching from the public proxy (different
487        // base URL than self.api_url), we use a plain client without auth
488        // headers to avoid leaking credentials to the proxy.
489        let resp = if use_auth {
490            self.client
491                .get(&url)
492                .header(header::ACCEPT, "application/octet-stream")
493                .send()
494                .await
495        } else {
496            let mut headers = HeaderMap::new();
497            headers.insert(
498                header::USER_AGENT,
499                HeaderValue::from_static(USER_AGENT_VALUE),
500            );
501            headers.insert(
502                header::ACCEPT,
503                HeaderValue::from_static("application/octet-stream"),
504            );
505
506            let plain_client = reqwest::Client::builder()
507                .default_headers(headers)
508                .build()
509                .expect("failed to build plain reqwest client");
510
511            plain_client.get(&url).send().await
512        };
513
514        let resp = resp.map_err(|e| {
515            ApiError::Network(format!(
516                "Network error fetching {} {}: {}",
517                label, identifier, e
518            ))
519        })?;
520
521        let status = resp.status();
522
523        match status {
524            StatusCode::OK => {
525                let bytes = resp.bytes().await.map_err(|e| {
526                    ApiError::Network(format!(
527                        "Error reading {} body for {}: {}",
528                        label, identifier, e
529                    ))
530                })?;
531                Ok(Some(bytes.to_vec()))
532            }
533            StatusCode::NOT_FOUND => Ok(None),
534            _ => {
535                let text = resp.text().await.unwrap_or_default();
536                Err(ApiError::Other(format!(
537                    "Failed to fetch {} {}: status {} - {}",
538                    label,
539                    identifier,
540                    status.as_u16(),
541                    text,
542                )))
543            }
544        }
545    }
546}
547
548// ── Free functions ────────────────────────────────────────────────────
549
/// Explicit overrides for environment-based API client construction.
///
/// Each `Some(value)` wins over the corresponding env var; `None` falls
/// back to env-var lookup (with the legacy `SOCKET_PATCH_*` shim where
/// applicable).
///
/// `Debug` is implemented by hand so that `api_token` is printed only as
/// `<redacted>`, never in clear text.
#[derive(Clone, Default)]
pub struct ApiClientEnvOverrides {
    pub api_url: Option<String>,
    pub api_token: Option<String>,
    pub org_slug: Option<String>,
    pub proxy_url: Option<String>,
}

impl std::fmt::Debug for ApiClientEnvOverrides {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiClientEnvOverrides")
            .field("api_url", &self.api_url)
            .field("api_token", &self.api_token.as_ref().map(|_| "<redacted>"))
            .field("org_slug", &self.org_slug)
            .field("proxy_url", &self.proxy_url)
            .finish()
    }
}
562
563/// Get an API client configured from environment variables.
564///
565/// If `SOCKET_API_TOKEN` is not set, the client will use the public patch
566/// API proxy which provides free access to free-tier patches without
567/// authentication.
568///
569/// When `SOCKET_API_TOKEN` is set but no org slug is provided (neither via
570/// argument nor `SOCKET_ORG_SLUG` env var), the function will attempt to
571/// auto-resolve the org slug by querying `GET /v0/organizations`.
572///
573/// # Environment variables
574///
575/// | Variable | Purpose |
576/// |---|---|
577/// | `SOCKET_API_URL` | Override the API URL (default `https://api.socket.dev`) |
578/// | `SOCKET_API_TOKEN` | API token for authenticated access |
579/// | `SOCKET_PROXY_URL` | Override the public proxy URL (default `https://patches-api.socket.dev`). Legacy: `SOCKET_PATCH_PROXY_URL`. |
580/// | `SOCKET_ORG_SLUG` | Organization slug |
581///
582/// Returns `(client, use_public_proxy)`.
583pub async fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) {
584    get_api_client_with_overrides(ApiClientEnvOverrides {
585        org_slug: org_slug.map(String::from),
586        ..ApiClientEnvOverrides::default()
587    })
588    .await
589}
590
591/// Like [`get_api_client_from_env`] but with explicit overrides for every
592/// env-driven knob. Each `Some(value)` in `overrides` wins over the
593/// corresponding env var. Used by CLI commands that expose `--api-url`,
594/// `--api-token`, `--org`, `--proxy-url` flags via [`crate::utils`] in the
595/// CLI crate.
596pub async fn get_api_client_with_overrides(overrides: ApiClientEnvOverrides) -> (ApiClient, bool) {
597    let api_token = overrides
598        .api_token
599        .or_else(|| std::env::var("SOCKET_API_TOKEN").ok())
600        .filter(|t| !t.is_empty());
601    let resolved_org_slug = overrides
602        .org_slug
603        .or_else(|| std::env::var("SOCKET_ORG_SLUG").ok());
604
605    if api_token.is_none() {
606        let proxy_url = overrides.proxy_url.unwrap_or_else(|| {
607            read_env_with_legacy("SOCKET_PROXY_URL", "SOCKET_PATCH_PROXY_URL")
608                .unwrap_or_else(|| DEFAULT_PATCH_API_PROXY_URL.to_string())
609        });
610        eprintln!("No SOCKET_API_TOKEN set. Using public patch API proxy (free patches only).");
611        let client = ApiClient::new(ApiClientOptions {
612            api_url: proxy_url,
613            api_token: None,
614            use_public_proxy: true,
615            org_slug: None,
616        });
617        return (client, true);
618    }
619
620    // Shape check the configured token before the network round-trip so
621    // a "you set the hash, not the token" mistake is loud and immediate.
622    if let Some(ref t) = api_token {
623        if let Some(msg) = validate_token_shape(t) {
624            eprintln!("{msg}");
625        }
626    }
627
628    let api_url = overrides
629        .api_url
630        .or_else(|| std::env::var("SOCKET_API_URL").ok())
631        .unwrap_or_else(|| DEFAULT_SOCKET_API_URL.to_string());
632
633    // Auto-resolve org slug if not provided
634    let final_org_slug = if resolved_org_slug.is_some() {
635        resolved_org_slug
636    } else {
637        let temp_client = ApiClient::new(ApiClientOptions {
638            api_url: api_url.clone(),
639            api_token: api_token.clone(),
640            use_public_proxy: false,
641            org_slug: None,
642        });
643        match temp_client.resolve_org_slug().await {
644            Ok(slug) => Some(slug),
645            Err(e) => {
646                eprintln!("Warning: Could not auto-detect organization: {e}");
647                if matches!(e, ApiError::Unauthorized(_)) {
648                    if let Some(ref t) = api_token {
649                        if looks_like_token_hash(t) {
650                            eprintln!(
651                                "  Hint: SOCKET_API_TOKEN starts with `{}-` \
652                                 which is the stored hash format. Set it to \
653                                 the raw `sktsec_..._api` value instead.",
654                                t.split('-').next().unwrap_or("sha512")
655                            );
656                        }
657                    }
658                }
659                None
660            }
661        }
662    };
663
664    let client = ApiClient::new(ApiClientOptions {
665        api_url,
666        api_token,
667        use_public_proxy: false,
668        org_slug: final_org_slug,
669    });
670    (client, false)
671}
672
673/// Build a public-proxy `ApiClient` from the same overrides used by
674/// [`get_api_client_with_overrides`], ignoring any API token.
675///
676/// Used by `scan` and `get` to retry against the public proxy after
677/// the authenticated endpoint returns 401/403 — a stale/revoked token
678/// shouldn't block access to free patches. The auth header is
679/// deliberately dropped (`api_token: None`).
680pub fn build_proxy_fallback_client(overrides: &ApiClientEnvOverrides) -> ApiClient {
681    let proxy_url = overrides.proxy_url.clone().unwrap_or_else(|| {
682        read_env_with_legacy("SOCKET_PROXY_URL", "SOCKET_PATCH_PROXY_URL")
683            .unwrap_or_else(|| DEFAULT_PATCH_API_PROXY_URL.to_string())
684    });
685    ApiClient::new(ApiClientOptions {
686        api_url: proxy_url,
687        api_token: None,
688        use_public_proxy: true,
689        org_slug: None,
690    })
691}
692
/// Return `true` when the configured token value looks like an
/// SRI-format hash (`sha512-<base64>` etc.) rather than a raw API
/// token. The server stores tokens *as* this hash; the CLI sometimes
/// gets configured with the storage representation by mistake (users
/// copy what they see in the dashboard). Surfacing this as a hint
/// short-circuits a confusing 401 round-trip.
///
/// Only the text before the first `-` is inspected; it must be exactly
/// `sha256`, `sha384` or `sha512`.
pub fn looks_like_token_hash(token: &str) -> bool {
    matches!(
        token.split_once('-'),
        Some(("sha256" | "sha384" | "sha512", _))
    )
}

/// Inspect a configured `SOCKET_API_TOKEN` value and return a
/// human-readable warning when the value doesn't match the canonical
/// Socket API token shape (`sktsec_<44 chars>_api`). Returns `None`
/// when the token looks valid, so the caller can ignore the result
/// without checking length.
///
/// The validation is intentionally a non-authoritative shape check —
/// the server's regex is the source of truth. We only flag values
/// that are *obviously* wrong (e.g. the storage hash, an empty
/// prefix/suffix) so a benign typo at the server's regex boundary
/// doesn't generate noise. Both `_api` and `_agent` suffixes are accepted.
///
/// The returned message redacts the middle of the token (first 8 +
/// last 4 characters) so a real token doesn't leak into stderr if a user
/// pastes one with a wrong suffix. Tokens of 12 characters or fewer are
/// shown in full.
pub fn validate_token_shape(token: &str) -> Option<String> {
    let has_prefix = token.starts_with("sktsec_");
    let has_suffix = token.ends_with("_api") || token.ends_with("_agent");
    // Byte length is fine here: well-formed tokens are ASCII, and this
    // lower bound (7 + 44 + 4) is meant to be lenient.
    let plausible_len = token.len() >= 55;
    if has_prefix && has_suffix && plausible_len {
        return None;
    }

    // Everything below works in *characters*. Mixing the byte length with
    // `chars().skip(..)` overshot on non-ASCII input, dropping the tail of
    // the preview and reporting bytes as "chars".
    let char_count = token.chars().count();
    let preview = if char_count <= 12 {
        token.to_string()
    } else {
        let head: String = token.chars().take(8).collect();
        let tail: String = token.chars().skip(char_count - 4).collect();
        format!("{head}...{tail}")
    };
    let hash_hint = if looks_like_token_hash(token) {
        "\n  That value looks like an SRI-format hash (sha###-<base64>) — \
         the server stores the *hash* of your token, not what you should \
         set here. Use the raw `sktsec_..._api` value shown when the token \
         was generated."
    } else {
        ""
    };
    Some(format!(
        "Warning: SOCKET_API_TOKEN does not look like a Socket API token \
         (expected `sktsec_<44 chars>_api`).{hash_hint}\n  \
         Got: {preview} ({char_count} chars). Continuing anyway; the server may \
         reject this with 401."
    ))
}
752
753/// Classify an [`ApiError`] as a candidate for the auth → proxy
754/// fallback. We only re-route on 401/403 (the stale-credentials
755/// signals). Network errors, rate limits, 404s, and 5xx surface as-is
756/// so they remain visible to the operator.
757pub fn is_fallback_candidate(err: &ApiError) -> bool {
758    matches!(err, ApiError::Unauthorized(_) | ApiError::Forbidden(_))
759}
760
761/// Choose an org slug from the list returned by `/v0/organizations`.
762///
763/// Returns an error when the list is empty, the sole slug when there is
764/// exactly one, and the first slug (with a warning) when there are several.
765///
766/// `fetch_organizations` collects from a `HashMap`, so the upstream order is
767/// not stable across runs. We sort by slug first so the chosen org *and* the
768/// warning text are deterministic — otherwise a token with multiple orgs
769/// could silently operate against a different org on each invocation.
770fn select_org_slug(mut orgs: Vec<crate::api::types::OrganizationInfo>) -> Result<String, ApiError> {
771    orgs.sort_by(|a, b| a.slug.cmp(&b.slug));
772    match orgs.len() {
773        0 => Err(ApiError::Other(
774            "No organizations found for this API token.".into(),
775        )),
776        1 => Ok(orgs.into_iter().next().unwrap().slug),
777        _ => {
778            let slugs: Vec<_> = orgs.iter().map(|o| o.slug.as_str()).collect();
779            let first = orgs[0].slug.clone();
780            eprintln!(
781                "Multiple organizations found: {}. Using \"{}\". \
782                 Pass --org to select a different one.",
783                slugs.join(", "),
784                first
785            );
786            Ok(first)
787        }
788    }
789}
790
791// ── Helpers ───────────────────────────────────────────────────────────
792
/// Percent-encode `input` for use as a URL path segment.
///
/// Bytes in the RFC 3986 "unreserved" set (`A-Z a-z 0-9 - _ . ~`) are kept
/// as-is. Every other byte becomes `%XX` with uppercase hex digits, so
/// multi-byte UTF-8 characters are encoded one byte at a time.
fn urlencoding_encode(input: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    input
        .bytes()
        .fold(String::with_capacity(input.len()), |mut encoded, byte| {
            let unreserved =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
            if unreserved {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            encoded
        })
}
810
/// Keep the first `max_chars` Unicode scalar values of `s`, appending `"..."`
/// when anything was cut.
///
/// The ellipsis is added on top of `max_chars`, so a truncated result is
/// `max_chars + 3` characters long; callers pass 97 to get a 100-character
/// title. A string with at most `max_chars` characters is returned unchanged.
/// Cutting on a `char` boundary, rather than byte slicing (`&s[..n]`), means
/// multi-byte UTF-8 input can never cause a panic.
fn truncate_to_chars(s: &str, max_chars: usize) -> String {
    // `nth(max_chars)` gives the byte offset of the first character that does
    // not fit, in a single pass over the string.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...", &s[..cut]),
        None => s.to_string(),
    }
}
820
/// Return `true` when `s` is exactly 64 ASCII hex digits, the textual form of
/// a SHA-256 digest. Upper-, lower- and mixed-case digits are all accepted.
fn is_valid_sha256_hex(s: &str) -> bool {
    const SHA256_HEX_LEN: usize = 64;

    if s.len() != SHA256_HEX_LEN {
        return false;
    }
    s.bytes()
        .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F'))
}
825
/// Return `true` when `s` uses the canonical UUID layout: five dash-separated
/// groups of 8, 4, 4, 4 and 12 hex digits (case-insensitive).
fn is_valid_uuid(s: &str) -> bool {
    const GROUP_LENGTHS: [usize; 5] = [8, 4, 4, 4, 12];

    let mut groups = s.split('-');
    let expected_groups_ok = GROUP_LENGTHS.iter().all(|&len| match groups.next() {
        Some(group) => group.len() == len && group.bytes().all(|b| b.is_ascii_hexdigit()),
        None => false,
    });

    // Reject trailing extra groups, e.g. a sixth `-xxxx` segment.
    expected_groups_ok && groups.next().is_none()
}
838
839/// Convert a `PatchSearchResult` into a `BatchPatchInfo`, extracting
840/// CVE/GHSA IDs and computing the highest severity.
841fn convert_search_result_to_batch_info(patch: PatchSearchResult) -> BatchPatchInfo {
842    let mut cve_ids: Vec<String> = Vec::new();
843    let mut ghsa_ids: Vec<String> = Vec::new();
844    let mut highest_severity: Option<String> = None;
845    let mut title = String::new();
846
847    let mut seen_cves: HashSet<String> = HashSet::new();
848
849    // `vulnerabilities` is a HashMap, so iterate in a stable (GHSA-id) order.
850    // Otherwise the chosen `title` (first non-empty summary) — and the
851    // first-seen tie-break for equal severities — would vary across runs.
852    let mut entries: Vec<(&String, &VulnerabilityResponse)> =
853        patch.vulnerabilities.iter().collect();
854    entries.sort_by(|a, b| a.0.cmp(b.0));
855
856    for (ghsa_id, vuln) in entries {
857        ghsa_ids.push(ghsa_id.clone());
858
859        for cve in &vuln.cves {
860            if seen_cves.insert(cve.clone()) {
861                cve_ids.push(cve.clone());
862            }
863        }
864
865        // Track highest severity (lower order number = higher severity)
866        let current_order = get_severity_order(highest_severity.as_deref());
867        let vuln_order = get_severity_order(Some(&vuln.severity));
868        if vuln_order < current_order {
869            highest_severity = Some(vuln.severity.clone());
870        }
871
872        // Use first non-empty summary as title
873        if title.is_empty() && !vuln.summary.is_empty() {
874            title = truncate_to_chars(&vuln.summary, 97);
875        }
876    }
877
878    // Use description as fallback title
879    if title.is_empty() && !patch.description.is_empty() {
880        title = truncate_to_chars(&patch.description, 97);
881    }
882
883    cve_ids.sort();
884    ghsa_ids.sort();
885
886    BatchPatchInfo {
887        uuid: patch.uuid,
888        purl: patch.purl,
889        tier: patch.tier,
890        cve_ids,
891        ghsa_ids,
892        severity: highest_severity,
893        title,
894    }
895}
896
897/// Assemble a [`BatchSearchResponse`] from the per-PURL [`SearchResponse`]s
898/// gathered by the public-proxy fallback (one GET per package).
899///
900/// A `None` entry is a query that errored and is skipped. The
901/// `can_access_paid_patches` capability is OR-aggregated across **every**
902/// successful response — independent of whether that response carried any
903/// patches — because it is a global capability signal, not a per-package
904/// one. The empty-patches check only governs whether a package is added to
905/// the `packages` list (an empty package would be noise), so it must run
906/// *after* the flag is observed; folding it into the same skip would drop a
907/// `canAccessPaidPatches: true` that arrived alongside an empty patch list.
908fn assemble_batch_from_individual(
909    results: Vec<(String, Option<SearchResponse>)>,
910) -> BatchSearchResponse {
911    let mut packages: Vec<BatchPackagePatches> = Vec::new();
912    let mut can_access_paid_patches = false;
913
914    for (purl, response) in results {
915        let Some(response) = response else { continue };
916
917        if response.can_access_paid_patches {
918            can_access_paid_patches = true;
919        }
920
921        if response.patches.is_empty() {
922            continue;
923        }
924
925        let batch_patches: Vec<BatchPatchInfo> = response
926            .patches
927            .into_iter()
928            .map(convert_search_result_to_batch_info)
929            .collect();
930
931        packages.push(BatchPackagePatches {
932            purl,
933            patches: batch_patches,
934        });
935    }
936
937    BatchSearchResponse {
938        packages,
939        can_access_paid_patches,
940    }
941}
942
943// ── Error type ────────────────────────────────────────────────────────
944
/// Errors returned by [`ApiClient`] methods.
///
/// Every variant wraps a human-readable message, and `Display` prints that
/// message verbatim (`#[error("{0}")]`), so callers can surface it directly.
/// Only `Unauthorized` and `Forbidden` are treated as triggers for the
/// authenticated → public-proxy fallback (see [`is_fallback_candidate`]).
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
    /// Transport-level failure. Presumably the request could not be sent or
    /// the connection failed; TODO confirm against the request helpers.
    #[error("{0}")]
    Network(String),

    /// A response could not be decoded. Presumably a deserialization failure;
    /// verify against callers.
    #[error("{0}")]
    Parse(String),

    /// Credentials were rejected (HTTP 401). Eligible for the public-proxy
    /// fallback.
    #[error("{0}")]
    Unauthorized(String),

    /// Access was denied (HTTP 403). Eligible for the public-proxy fallback.
    #[error("{0}")]
    Forbidden(String),

    /// The server signalled rate limiting. Presumably HTTP 429 (confirm).
    /// Not eligible for fallback, so it stays visible to the operator.
    #[error("{0}")]
    RateLimited(String),

    /// An identifier failed local validation, e.g. `fetch_diff` /
    /// `fetch_package` return this for a malformed patch UUID. Presumably
    /// also used for malformed SHA-256 blob hashes (verify).
    #[error("{0}")]
    InvalidHash(String),

    /// Any other failure, e.g. a token with no organizations
    /// (`select_org_slug`).
    #[error("{0}")]
    Other(String),
}
969
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_urlencoding_basic() {
        assert_eq!(urlencoding_encode("hello"), "hello");
        assert_eq!(urlencoding_encode("a b"), "a%20b");
        assert_eq!(
            urlencoding_encode("pkg:npm/lodash@4.17.21"),
            "pkg%3Anpm%2Flodash%404.17.21"
        );
    }

    #[test]
    fn test_is_valid_sha256_hex() {
        let valid = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
        assert!(is_valid_sha256_hex(valid));

        // Too short
        assert!(!is_valid_sha256_hex("abcdef"));
        // Non-hex
        assert!(!is_valid_sha256_hex(
            "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"
        ));
    }

    #[test]
    fn test_severity_order() {
        assert!(get_severity_order(Some("critical")) < get_severity_order(Some("high")));
        assert!(get_severity_order(Some("high")) < get_severity_order(Some("medium")));
        assert!(get_severity_order(Some("medium")) < get_severity_order(Some("low")));
        assert!(get_severity_order(Some("low")) < get_severity_order(None));
        assert_eq!(
            get_severity_order(Some("unknown")),
            get_severity_order(None)
        );
    }

    #[test]
    fn test_convert_search_result_to_batch_info() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-1234-5678-9abc".to_string(),
            VulnerabilityResponse {
                cves: vec!["CVE-2024-0001".into()],
                summary: "Test vulnerability".into(),
                severity: "high".into(),
                description: "A test vuln".into(),
            },
        );

        let patch = PatchSearchResult {
            uuid: "uuid-1".into(),
            purl: "pkg:npm/test@1.0.0".into(),
            published_at: "2024-01-01".into(),
            description: "A patch".into(),
            license: "MIT".into(),
            tier: "free".into(),
            vulnerabilities: vulns,
        };

        let info = convert_search_result_to_batch_info(patch);
        assert_eq!(info.uuid, "uuid-1");
        assert_eq!(info.cve_ids, vec!["CVE-2024-0001"]);
        assert_eq!(info.ghsa_ids, vec!["GHSA-1234-5678-9abc"]);
        assert_eq!(info.severity, Some("high".into()));
        assert_eq!(info.title, "Test vulnerability");
    }

    #[tokio::test]
    async fn test_get_api_client_from_env_no_token() {
        // Clear token to ensure public proxy mode
        std::env::remove_var("SOCKET_API_TOKEN");
        let (client, is_public) = get_api_client_from_env(None).await;
        assert!(is_public);
        assert!(client.use_public_proxy);
    }

    // ── Group 6: convert_search_result_to_batch_info edge cases ──────

    fn make_vuln(summary: &str, severity: &str, cves: Vec<&str>) -> VulnerabilityResponse {
        VulnerabilityResponse {
            cves: cves.into_iter().map(String::from).collect(),
            summary: summary.into(),
            severity: severity.into(),
            description: "desc".into(),
        }
    }

    fn make_patch(
        vulns: HashMap<String, VulnerabilityResponse>,
        description: &str,
    ) -> PatchSearchResult {
        PatchSearchResult {
            uuid: "uuid-1".into(),
            purl: "pkg:npm/test@1.0.0".into(),
            published_at: "2024-01-01".into(),
            description: description.into(),
            license: "MIT".into(),
            tier: "free".into(),
            vulnerabilities: vulns,
        }
    }

    #[test]
    fn test_convert_no_vulnerabilities() {
        let patch = make_patch(HashMap::new(), "A patch description");
        let info = convert_search_result_to_batch_info(patch);
        assert!(info.cve_ids.is_empty());
        assert!(info.ghsa_ids.is_empty());
        assert_eq!(info.title, "A patch description");
        assert!(info.severity.is_none());
    }

    #[test]
    fn test_convert_multiple_vulns_picks_highest_severity() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-1111".into(),
            make_vuln("Medium vuln", "medium", vec!["CVE-2024-0001"]),
        );
        vulns.insert(
            "GHSA-2222".into(),
            make_vuln("Critical vuln", "critical", vec!["CVE-2024-0002"]),
        );
        let patch = make_patch(vulns, "desc");
        let info = convert_search_result_to_batch_info(patch);
        assert_eq!(info.severity, Some("critical".into()));
    }

    #[test]
    fn test_convert_severity_tie_keeps_first_ghsa() {
        // Equal severities (case-insensitively) resolve to the advisory with
        // the lexicographically-first GHSA id, whatever the HashMap order.
        let mut vulns = HashMap::new();
        vulns.insert("GHSA-bbbb".into(), make_vuln("B", "high", vec![]));
        vulns.insert("GHSA-aaaa".into(), make_vuln("A", "HIGH", vec![]));
        let info = convert_search_result_to_batch_info(make_patch(vulns, "desc"));
        assert_eq!(info.severity, Some("HIGH".into()));
    }

    #[test]
    fn test_convert_duplicate_cves_deduplicated() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-1111".into(),
            make_vuln("Vuln A", "high", vec!["CVE-2024-0001"]),
        );
        vulns.insert(
            "GHSA-2222".into(),
            make_vuln("Vuln B", "high", vec!["CVE-2024-0001"]),
        );
        let patch = make_patch(vulns, "desc");
        let info = convert_search_result_to_batch_info(patch);
        // Same CVE in both vulns should only appear once
        let cve_count = info
            .cve_ids
            .iter()
            .filter(|c| *c == "CVE-2024-0001")
            .count();
        assert_eq!(cve_count, 1);
    }

    #[test]
    fn test_convert_title_truncated_at_100() {
        let long_summary = "x".repeat(150);
        let mut vulns = HashMap::new();
        vulns.insert("GHSA-1111".into(), make_vuln(&long_summary, "high", vec![]));
        let patch = make_patch(vulns, "desc");
        let info = convert_search_result_to_batch_info(patch);
        // Should be 97 chars + "..." = 100 chars
        assert_eq!(info.title.len(), 100);
        assert!(info.title.ends_with("..."));
    }

    #[test]
    fn test_convert_title_unicode_truncation() {
        // Create a summary with multi-byte chars that would panic with byte slicing
        // Each emoji is 4 bytes, so 30 emojis = 120 bytes but only 30 chars
        let emoji_summary = "\u{1F600}".repeat(30);
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-1111".into(),
            make_vuln(&emoji_summary, "high", vec![]),
        );
        let patch = make_patch(vulns, "desc");
        // This should NOT panic (validates the UTF-8 truncation fix)
        let info = convert_search_result_to_batch_info(patch);
        // 30 characters is under the 97-character limit, so the summary must
        // come through verbatim (a byte-based limit would have cut it).
        assert_eq!(info.title, emoji_summary);

        // Also test with description fallback: 120 chars -> 97 kept + "...".
        let patch2 = make_patch(HashMap::new(), &"\u{1F600}".repeat(120));
        let info2 = convert_search_result_to_batch_info(patch2);
        assert!(info2.title.ends_with("..."));
        assert_eq!(info2.title.chars().count(), 100);
    }

    #[test]
    fn test_convert_title_falls_back_to_description() {
        let mut vulns = HashMap::new();
        vulns.insert("GHSA-1111".into(), make_vuln("", "high", vec![]));
        let patch = make_patch(vulns, "Fallback desc");
        let info = convert_search_result_to_batch_info(patch);
        assert_eq!(info.title, "Fallback desc");
    }

    #[test]
    fn test_convert_empty_summary_and_description() {
        let mut vulns = HashMap::new();
        vulns.insert("GHSA-1111".into(), make_vuln("", "high", vec![]));
        let patch = make_patch(vulns, "");
        let info = convert_search_result_to_batch_info(patch);
        assert!(info.title.is_empty());
    }

    #[test]
    fn test_convert_cves_and_ghsas_sorted() {
        let mut vulns = HashMap::new();
        vulns.insert(
            "GHSA-cccc".into(),
            make_vuln("V1", "high", vec!["CVE-2024-0003"]),
        );
        vulns.insert(
            "GHSA-aaaa".into(),
            make_vuln("V2", "high", vec!["CVE-2024-0001"]),
        );
        vulns.insert(
            "GHSA-bbbb".into(),
            make_vuln("V3", "high", vec!["CVE-2024-0002"]),
        );
        let patch = make_patch(vulns, "desc");
        let info = convert_search_result_to_batch_info(patch);
        // Pin exact contents: a sortedness-only check would also pass if the
        // vectors came back empty.
        assert_eq!(
            info.cve_ids,
            vec!["CVE-2024-0001", "CVE-2024-0002", "CVE-2024-0003"]
        );
        assert_eq!(info.ghsa_ids, vec!["GHSA-aaaa", "GHSA-bbbb", "GHSA-cccc"]);
    }

    // ── truncate_to_chars boundaries ─────────────────────────────────

    #[test]
    fn test_truncate_to_chars_boundaries() {
        // Exactly `max_chars` characters: unchanged, no ellipsis.
        assert_eq!(truncate_to_chars("abc", 3), "abc");
        // One over: cut to `max_chars`, ellipsis appended on top.
        assert_eq!(truncate_to_chars("abcd", 3), "abc...");
        assert_eq!(truncate_to_chars("", 0), "");
        assert_eq!(truncate_to_chars("a", 0), "...");
        // Multi-byte characters count as single characters.
        assert_eq!(truncate_to_chars("\u{e9}\u{e9}\u{e9}", 2), "\u{e9}\u{e9}...");
    }

    // ── Group 7: urlencoding + SHA256 edge cases ─────────────────────

    #[test]
    fn test_urlencoding_unicode() {
        // Multi-byte UTF-8: 'é' = 0xC3 0xA9
        let encoded = urlencoding_encode("café");
        assert_eq!(encoded, "caf%C3%A9");
    }

    #[test]
    fn test_urlencoding_empty() {
        assert_eq!(urlencoding_encode(""), "");
    }

    #[test]
    fn test_urlencoding_all_safe_chars() {
        // Unreserved chars should pass through
        let safe = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~";
        assert_eq!(urlencoding_encode(safe), safe);
    }

    #[test]
    fn test_urlencoding_slash_and_at() {
        assert_eq!(urlencoding_encode("/"), "%2F");
        assert_eq!(urlencoding_encode("@"), "%40");
    }

    #[test]
    fn test_urlencoding_control_and_high_bytes() {
        // Both hex digits are zero-padded and uppercase.
        assert_eq!(urlencoding_encode("\0"), "%00");
        assert_eq!(urlencoding_encode("\u{7f}"), "%7F");
        // U+00FF encodes to the two UTF-8 bytes 0xC3 0xBF.
        assert_eq!(urlencoding_encode("\u{ff}"), "%C3%BF");
    }

    #[test]
    fn test_sha256_uppercase_valid() {
        let upper = "ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789";
        assert!(is_valid_sha256_hex(upper));
    }

    #[test]
    fn test_sha256_65_chars_invalid() {
        let too_long = "a".repeat(65);
        assert!(!is_valid_sha256_hex(&too_long));
    }

    #[test]
    fn test_sha256_63_chars_invalid() {
        let too_short = "a".repeat(63);
        assert!(!is_valid_sha256_hex(&too_short));
    }

    #[test]
    fn test_sha256_empty_invalid() {
        assert!(!is_valid_sha256_hex(""));
    }

    #[test]
    fn test_sha256_mixed_case_valid() {
        let mixed = "aAbBcCdDeEfF0123456789aAbBcCdDeEfF0123456789aAbBcCdDeEfF01234567";
        assert_eq!(mixed.len(), 64);
        assert!(is_valid_sha256_hex(mixed));
    }

    // ── UUID validation tests ───────────────────────────────────────

    #[test]
    fn test_is_valid_uuid_accepts_standard_form() {
        assert!(is_valid_uuid("80630680-4da6-45f9-bba8-b888e0ffd58c"));
        assert!(is_valid_uuid("00000000-0000-0000-0000-000000000000"));
        // Uppercase hex is acceptable.
        assert!(is_valid_uuid("ABCDEF01-2345-6789-ABCD-EF0123456789"));
    }

    #[test]
    fn test_is_valid_uuid_rejects_malformed() {
        assert!(!is_valid_uuid(""));
        assert!(!is_valid_uuid("not-a-uuid"));
        // Wrong segment count.
        assert!(!is_valid_uuid("80630680-4da6-45f9-bba8"));
        // Wrong length on first segment.
        assert!(!is_valid_uuid("8063068-4da6-45f9-bba8-b888e0ffd58c"));
        // Non-hex character.
        assert!(!is_valid_uuid("80630680-4da6-45f9-bba8-b888e0ffd58z"));
        // No dashes.
        assert!(!is_valid_uuid("80630680xxxxx"));
    }

    // ── fetch_diff / fetch_package validation tests ─────────────────
    //
    // These tests cover input validation only — they intentionally do
    // NOT hit the network. The shared `fetch_binary` helper handles the
    // transport, and `fetch_blob` already has integration coverage via
    // the e2e_npm test.

    #[tokio::test]
    async fn test_fetch_diff_rejects_invalid_uuid() {
        std::env::remove_var("SOCKET_API_TOKEN");
        let (client, _) = get_api_client_from_env(None).await;
        let result = client.fetch_diff("not-a-uuid").await;
        assert!(matches!(result, Err(ApiError::InvalidHash(_))));
    }

    #[tokio::test]
    async fn test_fetch_package_rejects_invalid_uuid() {
        std::env::remove_var("SOCKET_API_TOKEN");
        let (client, _) = get_api_client_from_env(None).await;
        let result = client.fetch_package("xxx").await;
        assert!(matches!(result, Err(ApiError::InvalidHash(_))));
    }

    // ── Token shape validation ─────────────────────────────────────────

    #[test]
    fn validate_token_shape_accepts_canonical_api_token() {
        // 7-char prefix + 44 random chars + 4-char `_api` suffix = 55 chars,
        // matching the server's SOCKET_TOKEN_REGEXP.
        let raw = format!("sktsec_{}_api", "x".repeat(44));
        assert_eq!(raw.len(), 55);
        assert!(validate_token_shape(&raw).is_none());
    }

    #[test]
    fn validate_token_shape_accepts_agent_token() {
        let raw = format!("sktsec_{}_agent", "x".repeat(44));
        assert!(validate_token_shape(&raw).is_none());
    }

    #[test]
    fn validate_token_shape_flags_sha512_hash() {
        let hash = "sha512-7aegAloeNsCqF1mpNL2J9MJ2dpIxQEwgKvXPml8XY2rrV2Za+\
                    bfj0yhG7RcqvqqLZ4iAH/drJjHjOqFkTGhddg==";
        let msg = validate_token_shape(hash).expect("hash must be flagged");
        assert!(
            msg.contains("does not look like a Socket API token"),
            "missing core warning; got: {msg}"
        );
        assert!(
            msg.contains("SRI-format hash"),
            "missing sha-hash hint; got: {msg}"
        );
        assert!(
            msg.contains("sktsec_"),
            "warning must point users at the correct prefix; got: {msg}"
        );
        // Token preview must not leak the whole value.
        assert!(
            !msg.contains("7RcqvqqLZ4iAH"),
            "middle of the value must be redacted; got: {msg}"
        );
    }

    #[test]
    fn validate_token_shape_flags_too_short() {
        let msg = validate_token_shape("sktsec_abc_api").expect("short token must be flagged");
        assert!(msg.contains("does not look like a Socket API token"));
        assert!(!msg.contains("SRI-format hash"));
    }

    #[test]
    fn validate_token_shape_flags_missing_suffix() {
        let raw = format!("sktsec_{}", "x".repeat(50));
        assert!(validate_token_shape(&raw).is_some());
    }

    #[test]
    fn looks_like_token_hash_recognizes_sri_prefixes() {
        assert!(looks_like_token_hash("sha256-abc"));
        assert!(looks_like_token_hash("sha384-abc"));
        assert!(looks_like_token_hash("sha512-abc"));
        assert!(!looks_like_token_hash("sktsec_xxx_api"));
        assert!(!looks_like_token_hash("hello"));
        assert!(!looks_like_token_hash(""));
    }

    // ── binary_url: proxy override must reach blob/diff/package fetches ──
    //
    // Regression: `fetch_binary` used to re-derive the proxy base from
    // `SOCKET_PROXY_URL`/default instead of the client's configured
    // `api_url`, so a `--proxy-url` override (which sets `api_url` but no env
    // var) was honored for searches yet silently ignored for downloads.

    fn proxy_client(api_url: &str) -> ApiClient {
        ApiClient::new(ApiClientOptions {
            api_url: api_url.into(),
            api_token: None,
            use_public_proxy: true,
            org_slug: None,
        })
    }

    #[test]
    fn binary_url_proxy_uses_configured_api_url() {
        let client = proxy_client("https://custom.proxy.example");
        let (url, use_auth) = client.binary_url("blob", "deadbeef");
        assert!(!use_auth);
        assert_eq!(url, "https://custom.proxy.example/patch/blob/deadbeef");
    }

    #[test]
    fn binary_url_proxy_covers_diff_and_package() {
        let client = proxy_client("https://custom.proxy.example");
        assert_eq!(
            client.binary_url("diff", "uuid-1").0,
            "https://custom.proxy.example/patch/diff/uuid-1"
        );
        assert_eq!(
            client.binary_url("package", "uuid-1").0,
            "https://custom.proxy.example/patch/package/uuid-1"
        );
    }

    #[test]
    fn binary_url_proxy_trims_trailing_slash() {
        // `new()` trims the trailing slash on api_url; binary_url also trims
        // defensively so the path never ends up with a doubled separator.
        let client = proxy_client("https://custom.proxy.example/");
        assert_eq!(
            client.binary_url("blob", "x").0,
            "https://custom.proxy.example/patch/blob/x"
        );
    }

    #[test]
    fn binary_url_authenticated_uses_org_path() {
        let client = ApiClient::new(ApiClientOptions {
            api_url: "https://api.socket.dev".into(),
            api_token: Some("sktsec_x_api".into()),
            use_public_proxy: false,
            org_slug: Some("my-org".into()),
        });
        let (url, use_auth) = client.binary_url("diff", "uuid-123");
        assert!(use_auth);
        assert_eq!(
            url,
            "https://api.socket.dev/v0/orgs/my-org/patches/diff/uuid-123"
        );
    }

    // ── is_fallback_candidate: only auth failures re-route ──────────────

    #[test]
    fn is_fallback_candidate_only_for_auth_errors() {
        assert!(is_fallback_candidate(&ApiError::Unauthorized("401".into())));
        assert!(is_fallback_candidate(&ApiError::Forbidden("403".into())));
        assert!(!is_fallback_candidate(&ApiError::Network("down".into())));
        assert!(!is_fallback_candidate(&ApiError::Parse("bad json".into())));
        assert!(!is_fallback_candidate(&ApiError::RateLimited("429".into())));
        assert!(!is_fallback_candidate(&ApiError::InvalidHash("bad".into())));
        assert!(!is_fallback_candidate(&ApiError::Other("other".into())));
    }

    // ── select_org_slug: deterministic org selection ────────────────────

    fn org(slug: &str) -> crate::api::types::OrganizationInfo {
        crate::api::types::OrganizationInfo {
            id: format!("id-{slug}"),
            name: Some(slug.to_string()),
            image: None,
            plan: "free".into(),
            slug: slug.into(),
        }
    }

    #[test]
    fn select_org_slug_errors_when_empty() {
        assert!(matches!(select_org_slug(vec![]), Err(ApiError::Other(_))));
    }

    #[test]
    fn select_org_slug_returns_sole_org() {
        assert_eq!(select_org_slug(vec![org("acme")]).unwrap(), "acme");
    }

    #[test]
    fn select_org_slug_is_deterministic_for_multiple() {
        // Regardless of the (HashMap-derived) input order, the
        // lexicographically-first slug is chosen so repeated runs agree.
        let a = select_org_slug(vec![org("zeta"), org("alpha"), org("mid")]).unwrap();
        let b = select_org_slug(vec![org("mid"), org("zeta"), org("alpha")]).unwrap();
        assert_eq!(a, "alpha");
        assert_eq!(b, "alpha");
    }

    // ── assemble_batch_from_individual: proxy-fallback aggregation ──────

    fn search_response(
        purl: &str,
        can_access_paid_patches: bool,
        patch_uuids: &[&str],
    ) -> SearchResponse {
        SearchResponse {
            patches: patch_uuids
                .iter()
                .map(|uuid| PatchSearchResult {
                    uuid: (*uuid).into(),
                    purl: purl.into(),
                    published_at: "2024-01-01".into(),
                    description: "desc".into(),
                    license: "MIT".into(),
                    tier: "free".into(),
                    vulnerabilities: HashMap::new(),
                })
                .collect(),
            can_access_paid_patches,
        }
    }

    #[test]
    fn assemble_batch_collects_patches_per_purl() {
        let results = vec![
            (
                "pkg:npm/a@1".to_string(),
                Some(search_response("pkg:npm/a@1", false, &["uuid-a"])),
            ),
            (
                "pkg:npm/b@1".to_string(),
                Some(search_response(
                    "pkg:npm/b@1",
                    false,
                    &["uuid-b1", "uuid-b2"],
                )),
            ),
        ];
        let batch = assemble_batch_from_individual(results);
        assert_eq!(batch.packages.len(), 2);
        assert!(!batch.can_access_paid_patches);
        let a = batch
            .packages
            .iter()
            .find(|p| p.purl == "pkg:npm/a@1")
            .unwrap();
        assert_eq!(a.patches.len(), 1);
        let b = batch
            .packages
            .iter()
            .find(|p| p.purl == "pkg:npm/b@1")
            .unwrap();
        assert_eq!(b.patches.len(), 2);
    }

    #[test]
    fn assemble_batch_skips_errored_and_empty_responses() {
        // None = query errored; an empty patch list contributes no package.
        let results = vec![
            ("pkg:npm/err@1".to_string(), None),
            (
                "pkg:npm/empty@1".to_string(),
                Some(search_response("pkg:npm/empty@1", false, &[])),
            ),
            (
                "pkg:npm/ok@1".to_string(),
                Some(search_response("pkg:npm/ok@1", false, &["uuid-ok"])),
            ),
        ];
        let batch = assemble_batch_from_individual(results);
        // Only the package with at least one patch is listed.
        assert_eq!(batch.packages.len(), 1);
        assert_eq!(batch.packages[0].purl, "pkg:npm/ok@1");
    }

    #[test]
    fn assemble_batch_aggregates_paid_flag_across_all_responses() {
        // OR-aggregation: any response with the flag set flips the aggregate.
        let results = vec![
            (
                "pkg:npm/a@1".to_string(),
                Some(search_response("pkg:npm/a@1", false, &["uuid-a"])),
            ),
            (
                "pkg:npm/b@1".to_string(),
                Some(search_response("pkg:npm/b@1", true, &["uuid-b"])),
            ),
        ];
        let batch = assemble_batch_from_individual(results);
        assert!(batch.can_access_paid_patches);
    }

    #[test]
    fn assemble_batch_keeps_paid_flag_from_empty_patch_response() {
        // Regression: the capability flag must survive even when the response
        // that carries it has *no* patches. The empty-patch response must not
        // be listed as a package, but its `canAccessPaidPatches: true` must
        // still flip the aggregate flag — a fused skip would have dropped it.
        let results = vec![
            (
                "pkg:npm/free@1".to_string(),
                Some(search_response("pkg:npm/free@1", false, &["uuid-free"])),
            ),
            (
                "pkg:npm/paid-only@1".to_string(),
                Some(search_response("pkg:npm/paid-only@1", true, &[])),
            ),
        ];
        let batch = assemble_batch_from_individual(results);
        assert!(
            batch.can_access_paid_patches,
            "paid-access flag from an empty-patch response was dropped"
        );
        // The empty-patch package must not appear in the listing.
        assert_eq!(batch.packages.len(), 1);
        assert_eq!(batch.packages[0].purl, "pkg:npm/free@1");
    }

    // ── convert: title selection is deterministic ───────────────────────

    #[test]
    fn test_convert_title_deterministic_across_iteration_order() {
        // Two vulns, each with a non-empty summary. The title must always be
        // drawn from the lexicographically-first GHSA id so the value is
        // stable across runs (HashMap iteration order is not).
        let mut vulns = HashMap::new();
        vulns.insert("GHSA-zzzz".into(), make_vuln("Z summary", "high", vec![]));
        vulns.insert("GHSA-aaaa".into(), make_vuln("A summary", "high", vec![]));
        let patch = make_patch(vulns, "desc");
        let info = convert_search_result_to_batch_info(patch);
        assert_eq!(info.title, "A summary");
    }
}