1use std::collections::HashSet;
2
3use reqwest::header::{self, HeaderMap, HeaderValue};
4use reqwest::StatusCode;
5use serde::Serialize;
6
7use crate::api::types::*;
8use crate::constants::{
9 DEFAULT_PATCH_API_PROXY_URL, DEFAULT_SOCKET_API_URL, USER_AGENT as USER_AGENT_VALUE,
10};
11
/// Reports whether verbose debug logging is switched on.
///
/// Logging is enabled only when the `SOCKET_PATCH_DEBUG` environment variable
/// is set to exactly `1` or `true`; an unset or unreadable variable, or any
/// other value, disables it.
fn is_debug_enabled() -> bool {
    std::env::var("SOCKET_PATCH_DEBUG")
        .map(|flag| matches!(flag.as_str(), "1" | "true"))
        .unwrap_or(false)
}
19
20fn debug_log(message: &str) {
22 if is_debug_enabled() {
23 eprintln!("[socket-patch debug] {}", message);
24 }
25}
26
/// Maps a severity label to a sort rank where a lower value is more severe.
///
/// Labels are compared case-insensitively: `critical` = 0, `high` = 1,
/// `medium` = 2, `low` = 3. `None` and unrecognised labels rank last (4).
fn get_severity_order(severity: Option<&str>) -> u8 {
    // Known labels, most severe first; the index doubles as the rank.
    const RANKED: [&str; 4] = ["critical", "high", "medium", "low"];
    const UNKNOWN_RANK: u8 = 4;

    severity
        .and_then(|label| {
            RANKED
                .iter()
                .position(|known| label.eq_ignore_ascii_case(known))
        })
        // The index is always < 4, so the narrowing cast cannot truncate.
        .map_or(UNKNOWN_RANK, |rank| rank as u8)
}
37
/// Configuration consumed by `ApiClient::new`.
#[derive(Debug, Clone)]
pub struct ApiClientOptions {
    /// Base URL that request paths are appended to. Trailing `/` characters
    /// are stripped when the client is constructed.
    pub api_url: String,
    /// Optional API token. When present it is sent as an
    /// `Authorization: Bearer <token>` default header.
    pub api_token: Option<String>,
    /// Selects the `/patch/...` routes instead of the organization-scoped
    /// `/v0/orgs/{slug}/patches/...` routes.
    pub use_public_proxy: bool,
    /// Organization slug used when a call does not pass one explicitly.
    pub org_slug: Option<String>,
}
50
/// HTTP client for the patch API endpoints used in this module.
///
/// The type is `Clone`; the batch fallback clones it once per spawned query.
#[derive(Debug, Clone)]
pub struct ApiClient {
    /// Underlying HTTP client carrying the default headers set up in `new`.
    client: reqwest::Client,
    /// Base URL with trailing slashes removed.
    api_url: String,
    /// Token the client was configured with, if any.
    api_token: Option<String>,
    /// Whether requests use the `/patch/...` proxy routes.
    use_public_proxy: bool,
    /// Fallback organization slug for organization-scoped routes.
    org_slug: Option<String>,
}
64
/// JSON request body posted to the organization-scoped batch search endpoint.
#[derive(Serialize)]
struct BatchSearchBody {
    /// One entry per package URL being looked up.
    components: Vec<BatchComponent>,
}
70
/// A single package entry inside a `BatchSearchBody`.
#[derive(Serialize)]
struct BatchComponent {
    /// Package URL (purl) to search patches for.
    purl: String,
}
75
76impl ApiClient {
77 pub fn new(options: ApiClientOptions) -> Self {
82 let api_url = options.api_url.trim_end_matches('/').to_string();
83
84 let mut default_headers = HeaderMap::new();
85 default_headers.insert(
86 header::USER_AGENT,
87 HeaderValue::from_static(USER_AGENT_VALUE),
88 );
89 default_headers.insert(
90 header::ACCEPT,
91 HeaderValue::from_static("application/json"),
92 );
93
94 if let Some(ref token) = options.api_token {
95 if let Ok(hv) = HeaderValue::from_str(&format!("Bearer {}", token)) {
96 default_headers.insert(header::AUTHORIZATION, hv);
97 }
98 }
99
100 let client = reqwest::Client::builder()
101 .default_headers(default_headers)
102 .build()
103 .expect("failed to build reqwest client");
104
105 Self {
106 client,
107 api_url,
108 api_token: options.api_token,
109 use_public_proxy: options.use_public_proxy,
110 org_slug: options.org_slug,
111 }
112 }
113
114 pub fn api_token(&self) -> Option<&String> {
116 self.api_token.as_ref()
117 }
118
119 pub fn org_slug(&self) -> Option<&String> {
121 self.org_slug.as_ref()
122 }
123
124 async fn get_json<T: serde::de::DeserializeOwned>(
128 &self,
129 path: &str,
130 ) -> Result<Option<T>, ApiError> {
131 let url = format!("{}{}", self.api_url, path);
132 debug_log(&format!("GET {}", url));
133
134 let resp = self
135 .client
136 .get(&url)
137 .send()
138 .await
139 .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
140
141 Self::handle_json_response(resp, self.use_public_proxy).await
142 }
143
144 async fn post_json<T: serde::de::DeserializeOwned, B: Serialize>(
146 &self,
147 path: &str,
148 body: &B,
149 ) -> Result<Option<T>, ApiError> {
150 let url = format!("{}{}", self.api_url, path);
151 debug_log(&format!("POST {}", url));
152
153 let resp = self
154 .client
155 .post(&url)
156 .header(header::CONTENT_TYPE, "application/json")
157 .json(body)
158 .send()
159 .await
160 .map_err(|e| ApiError::Network(format!("Network error: {}", e)))?;
161
162 Self::handle_json_response(resp, self.use_public_proxy).await
163 }
164
165 async fn handle_json_response<T: serde::de::DeserializeOwned>(
167 resp: reqwest::Response,
168 use_public_proxy: bool,
169 ) -> Result<Option<T>, ApiError> {
170 let status = resp.status();
171
172 match status {
173 StatusCode::OK => {
174 let body = resp
175 .json::<T>()
176 .await
177 .map_err(|e| ApiError::Parse(format!("Failed to parse response: {}", e)))?;
178 Ok(Some(body))
179 }
180 StatusCode::NOT_FOUND => Ok(None),
181 StatusCode::UNAUTHORIZED => {
182 Err(ApiError::Unauthorized("Unauthorized: Invalid API token".into()))
183 }
184 StatusCode::FORBIDDEN => {
185 let msg = if use_public_proxy {
186 "Forbidden: This patch is only available to paid subscribers. \
187 Sign up at https://socket.dev to access paid patches."
188 } else {
189 "Forbidden: Access denied. This may be a paid patch or \
190 you may not have access to this organization."
191 };
192 Err(ApiError::Forbidden(msg.into()))
193 }
194 StatusCode::TOO_MANY_REQUESTS => {
195 Err(ApiError::RateLimited(
196 "Rate limit exceeded. Please try again later.".into(),
197 ))
198 }
199 _ => {
200 let text = resp.text().await.unwrap_or_default();
201 Err(ApiError::Other(format!(
202 "API request failed with status {}: {}",
203 status.as_u16(),
204 text
205 )))
206 }
207 }
208 }
209
210 pub async fn fetch_patch(
216 &self,
217 org_slug: Option<&str>,
218 uuid: &str,
219 ) -> Result<Option<PatchResponse>, ApiError> {
220 let path = if self.use_public_proxy {
221 format!("/patch/view/{}", uuid)
222 } else {
223 let slug = org_slug
224 .or(self.org_slug.as_deref())
225 .unwrap_or("default");
226 format!("/v0/orgs/{}/patches/view/{}", slug, uuid)
227 };
228 self.get_json(&path).await
229 }
230
231 pub async fn search_patches_by_cve(
233 &self,
234 org_slug: Option<&str>,
235 cve_id: &str,
236 ) -> Result<SearchResponse, ApiError> {
237 let encoded = urlencoding_encode(cve_id);
238 let path = if self.use_public_proxy {
239 format!("/patch/by-cve/{}", encoded)
240 } else {
241 let slug = org_slug
242 .or(self.org_slug.as_deref())
243 .unwrap_or("default");
244 format!("/v0/orgs/{}/patches/by-cve/{}", slug, encoded)
245 };
246 let result = self.get_json::<SearchResponse>(&path).await?;
247 Ok(result.unwrap_or_else(|| SearchResponse {
248 patches: Vec::new(),
249 can_access_paid_patches: false,
250 }))
251 }
252
253 pub async fn search_patches_by_ghsa(
255 &self,
256 org_slug: Option<&str>,
257 ghsa_id: &str,
258 ) -> Result<SearchResponse, ApiError> {
259 let encoded = urlencoding_encode(ghsa_id);
260 let path = if self.use_public_proxy {
261 format!("/patch/by-ghsa/{}", encoded)
262 } else {
263 let slug = org_slug
264 .or(self.org_slug.as_deref())
265 .unwrap_or("default");
266 format!("/v0/orgs/{}/patches/by-ghsa/{}", slug, encoded)
267 };
268 let result = self.get_json::<SearchResponse>(&path).await?;
269 Ok(result.unwrap_or_else(|| SearchResponse {
270 patches: Vec::new(),
271 can_access_paid_patches: false,
272 }))
273 }
274
275 pub async fn search_patches_by_package(
280 &self,
281 org_slug: Option<&str>,
282 purl: &str,
283 ) -> Result<SearchResponse, ApiError> {
284 let encoded = urlencoding_encode(purl);
285 let path = if self.use_public_proxy {
286 format!("/patch/by-package/{}", encoded)
287 } else {
288 let slug = org_slug
289 .or(self.org_slug.as_deref())
290 .unwrap_or("default");
291 format!("/v0/orgs/{}/patches/by-package/{}", slug, encoded)
292 };
293 let result = self.get_json::<SearchResponse>(&path).await?;
294 Ok(result.unwrap_or_else(|| SearchResponse {
295 patches: Vec::new(),
296 can_access_paid_patches: false,
297 }))
298 }
299
300 pub async fn search_patches_batch(
309 &self,
310 org_slug: Option<&str>,
311 purls: &[String],
312 ) -> Result<BatchSearchResponse, ApiError> {
313 if !self.use_public_proxy {
314 let slug = org_slug
315 .or(self.org_slug.as_deref())
316 .unwrap_or("default");
317 let path = format!("/v0/orgs/{}/patches/batch", slug);
318 let body = BatchSearchBody {
319 components: purls
320 .iter()
321 .map(|p| BatchComponent { purl: p.clone() })
322 .collect(),
323 };
324 let result = self.post_json::<BatchSearchResponse, _>(&path, &body).await?;
325 return Ok(result.unwrap_or_else(|| BatchSearchResponse {
326 packages: Vec::new(),
327 can_access_paid_patches: false,
328 }));
329 }
330
331 self.search_patches_batch_via_individual_queries(purls).await
333 }
334
335 async fn search_patches_batch_via_individual_queries(
341 &self,
342 purls: &[String],
343 ) -> Result<BatchSearchResponse, ApiError> {
344 const CONCURRENCY_LIMIT: usize = 10;
345
346 let mut packages: Vec<BatchPackagePatches> = Vec::new();
347 let mut can_access_paid_patches = false;
348
349 let mut all_results: Vec<(String, Option<SearchResponse>)> = Vec::new();
351
352 for chunk in purls.chunks(CONCURRENCY_LIMIT) {
353 let mut join_set = tokio::task::JoinSet::new();
355
356 for purl in chunk {
357 let purl = purl.clone();
358 let client = self.clone();
359 join_set.spawn(async move {
360 let resp = client.search_patches_by_package(None, &purl).await;
361 match resp {
362 Ok(r) => (purl, Some(r)),
363 Err(e) => {
364 debug_log(&format!("Error fetching patches for {}: {}", purl, e));
365 (purl, None)
366 }
367 }
368 });
369 }
370
371 while let Some(result) = join_set.join_next().await {
372 match result {
373 Ok(pair) => all_results.push(pair),
374 Err(e) => {
375 debug_log(&format!("Task join error: {}", e));
376 }
377 }
378 }
379 }
380
381 for (purl, response) in all_results {
383 let response = match response {
384 Some(r) if !r.patches.is_empty() => r,
385 _ => continue,
386 };
387
388 if response.can_access_paid_patches {
389 can_access_paid_patches = true;
390 }
391
392 let batch_patches: Vec<BatchPatchInfo> = response
393 .patches
394 .into_iter()
395 .map(convert_search_result_to_batch_info)
396 .collect();
397
398 packages.push(BatchPackagePatches {
399 purl,
400 patches: batch_patches,
401 });
402 }
403
404 Ok(BatchSearchResponse {
405 packages,
406 can_access_paid_patches,
407 })
408 }
409
410 pub async fn fetch_organizations(
412 &self,
413 ) -> Result<Vec<crate::api::types::OrganizationInfo>, ApiError> {
414 let path = "/v0/organizations";
415 match self
416 .get_json::<crate::api::types::OrganizationsResponse>(path)
417 .await?
418 {
419 Some(resp) => Ok(resp.organizations.into_values().collect()),
420 None => Ok(Vec::new()),
421 }
422 }
423
424 pub async fn resolve_org_slug(&self) -> Result<String, ApiError> {
430 let orgs = self.fetch_organizations().await?;
431 match orgs.len() {
432 0 => Err(ApiError::Other(
433 "No organizations found for this API token.".into(),
434 )),
435 1 => Ok(orgs.into_iter().next().unwrap().slug),
436 _ => {
437 let slugs: Vec<_> = orgs.iter().map(|o| o.slug.as_str()).collect();
438 let first = orgs[0].slug.clone();
439 eprintln!(
440 "Multiple organizations found: {}. Using \"{}\". \
441 Pass --org to select a different one.",
442 slugs.join(", "),
443 first
444 );
445 Ok(first)
446 }
447 }
448 }
449
450 pub async fn fetch_blob(&self, hash: &str) -> Result<Option<Vec<u8>>, ApiError> {
456 if !is_valid_sha256_hex(hash) {
458 return Err(ApiError::InvalidHash(format!(
459 "Invalid hash format: {}. Expected SHA256 hash (64 hex characters).",
460 hash
461 )));
462 }
463
464 let (url, use_auth) =
465 if self.api_token.is_some() && self.org_slug.is_some() && !self.use_public_proxy {
466 let slug = self.org_slug.as_deref().unwrap();
468 let u = format!("{}/v0/orgs/{}/patches/blob/{}", self.api_url, slug, hash);
469 (u, true)
470 } else {
471 let proxy_url = std::env::var("SOCKET_PATCH_PROXY_URL")
473 .unwrap_or_else(|_| DEFAULT_PATCH_API_PROXY_URL.to_string());
474 let u = format!("{}/patch/blob/{}", proxy_url.trim_end_matches('/'), hash);
475 (u, false)
476 };
477
478 debug_log(&format!("GET blob {}", url));
479
480 let resp = if use_auth {
484 self.client
485 .get(&url)
486 .header(header::ACCEPT, "application/octet-stream")
487 .send()
488 .await
489 } else {
490 let mut headers = HeaderMap::new();
491 headers.insert(
492 header::USER_AGENT,
493 HeaderValue::from_static(USER_AGENT_VALUE),
494 );
495 headers.insert(
496 header::ACCEPT,
497 HeaderValue::from_static("application/octet-stream"),
498 );
499
500 let plain_client = reqwest::Client::builder()
501 .default_headers(headers)
502 .build()
503 .expect("failed to build plain reqwest client");
504
505 plain_client.get(&url).send().await
506 };
507
508 let resp = resp.map_err(|e| {
509 ApiError::Network(format!("Network error fetching blob {}: {}", hash, e))
510 })?;
511
512 let status = resp.status();
513
514 match status {
515 StatusCode::OK => {
516 let bytes = resp.bytes().await.map_err(|e| {
517 ApiError::Network(format!("Error reading blob body for {}: {}", hash, e))
518 })?;
519 Ok(Some(bytes.to_vec()))
520 }
521 StatusCode::NOT_FOUND => Ok(None),
522 _ => {
523 let text = resp.text().await.unwrap_or_default();
524 Err(ApiError::Other(format!(
525 "Failed to fetch blob {}: status {} - {}",
526 hash,
527 status.as_u16(),
528 text,
529 )))
530 }
531 }
532 }
533}
534
535pub async fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) {
558 let api_token = std::env::var("SOCKET_API_TOKEN")
559 .ok()
560 .filter(|t| !t.is_empty());
561 let resolved_org_slug = org_slug
562 .map(String::from)
563 .or_else(|| std::env::var("SOCKET_ORG_SLUG").ok());
564
565 if api_token.is_none() {
566 let proxy_url = std::env::var("SOCKET_PATCH_PROXY_URL")
567 .unwrap_or_else(|_| DEFAULT_PATCH_API_PROXY_URL.to_string());
568 eprintln!(
569 "No SOCKET_API_TOKEN set. Using public patch API proxy (free patches only)."
570 );
571 let client = ApiClient::new(ApiClientOptions {
572 api_url: proxy_url,
573 api_token: None,
574 use_public_proxy: true,
575 org_slug: None,
576 });
577 return (client, true);
578 }
579
580 let api_url =
581 std::env::var("SOCKET_API_URL").unwrap_or_else(|_| DEFAULT_SOCKET_API_URL.to_string());
582
583 let final_org_slug = if resolved_org_slug.is_some() {
585 resolved_org_slug
586 } else {
587 let temp_client = ApiClient::new(ApiClientOptions {
588 api_url: api_url.clone(),
589 api_token: api_token.clone(),
590 use_public_proxy: false,
591 org_slug: None,
592 });
593 match temp_client.resolve_org_slug().await {
594 Ok(slug) => Some(slug),
595 Err(e) => {
596 eprintln!("Warning: Could not auto-detect organization: {e}");
597 None
598 }
599 }
600 };
601
602 let client = ApiClient::new(ApiClientOptions {
603 api_url,
604 api_token,
605 use_public_proxy: false,
606 org_slug: final_org_slug,
607 });
608 (client, false)
609}
610
/// Percent-encodes `input` for use as a single URL path segment.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` are copied through unchanged;
/// every other byte of the UTF-8 encoding becomes `%XX` with upper-case hex
/// digits.
fn urlencoding_encode(input: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(input.len());
    for &byte in input.as_bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
630
/// Shortens `s` to at most `max_chars` characters, appending `...` when
/// anything was cut off.
///
/// Lengths are counted in `char`s rather than bytes, so multi-byte characters
/// are never split. Strings that already fit are returned unchanged.
fn truncate_to_chars(s: &str, max_chars: usize) -> String {
    // The byte offset of the character at index `max_chars` exists only when
    // the string is longer than the limit.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...", &s[..cut]),
        None => s.to_string(),
    }
}
640
/// Checks that `s` looks like a SHA-256 digest: exactly 64 ASCII hex digits,
/// in either case.
fn is_valid_sha256_hex(s: &str) -> bool {
    const SHA256_HEX_LEN: usize = 64;

    if s.len() != SHA256_HEX_LEN {
        return false;
    }
    s.chars().all(|c| c.is_ascii_hexdigit())
}
645
646fn convert_search_result_to_batch_info(patch: PatchSearchResult) -> BatchPatchInfo {
649 let mut cve_ids: Vec<String> = Vec::new();
650 let mut ghsa_ids: Vec<String> = Vec::new();
651 let mut highest_severity: Option<String> = None;
652 let mut title = String::new();
653
654 let mut seen_cves: HashSet<String> = HashSet::new();
655
656 for (ghsa_id, vuln) in &patch.vulnerabilities {
657 ghsa_ids.push(ghsa_id.clone());
658
659 for cve in &vuln.cves {
660 if seen_cves.insert(cve.clone()) {
661 cve_ids.push(cve.clone());
662 }
663 }
664
665 let current_order = get_severity_order(highest_severity.as_deref());
667 let vuln_order = get_severity_order(Some(&vuln.severity));
668 if vuln_order < current_order {
669 highest_severity = Some(vuln.severity.clone());
670 }
671
672 if title.is_empty() && !vuln.summary.is_empty() {
674 title = truncate_to_chars(&vuln.summary, 97);
675 }
676 }
677
678 if title.is_empty() && !patch.description.is_empty() {
680 title = truncate_to_chars(&patch.description, 97);
681 }
682
683 cve_ids.sort();
684 ghsa_ids.sort();
685
686 BatchPatchInfo {
687 uuid: patch.uuid,
688 purl: patch.purl,
689 tier: patch.tier,
690 cve_ids,
691 ghsa_ids,
692 severity: highest_severity,
693 title,
694 }
695}
696
/// Errors produced by the API client.
///
/// Every variant carries a ready-to-display message, which is also its
/// `Display` output.
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
    /// The request could not be sent, or a response body could not be read.
    #[error("{0}")]
    Network(String),

    /// A successful response body could not be deserialized.
    #[error("{0}")]
    Parse(String),

    /// The server answered 401.
    #[error("{0}")]
    Unauthorized(String),

    /// The server answered 403.
    #[error("{0}")]
    Forbidden(String),

    /// The server answered 429.
    #[error("{0}")]
    RateLimited(String),

    /// A blob hash was not 64 hexadecimal characters.
    #[error("{0}")]
    InvalidHash(String),

    /// Any other failure, such as an unexpected HTTP status or a token with
    /// no organizations.
    #[error("{0}")]
    Other(String),
}
723
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a vulnerability fixture with a placeholder description.
    fn make_vuln(summary: &str, severity: &str, cves: Vec<&str>) -> VulnerabilityResponse {
        VulnerabilityResponse {
            cves: cves.into_iter().map(String::from).collect(),
            summary: summary.into(),
            severity: severity.into(),
            description: "desc".into(),
        }
    }

    /// Builds a patch fixture around the given vulnerabilities.
    fn make_patch(
        vulns: HashMap<String, VulnerabilityResponse>,
        description: &str,
    ) -> PatchSearchResult {
        PatchSearchResult {
            uuid: "uuid-1".into(),
            purl: "pkg:npm/test@1.0.0".into(),
            published_at: "2024-01-01".into(),
            description: description.into(),
            license: "MIT".into(),
            tier: "free".into(),
            vulnerabilities: vulns,
        }
    }

    /// Converts a patch built from `(advisory id, vulnerability)` pairs.
    fn convert(
        entries: Vec<(&str, VulnerabilityResponse)>,
        description: &str,
    ) -> BatchPatchInfo {
        let vulns: HashMap<String, VulnerabilityResponse> = entries
            .into_iter()
            .map(|(id, vuln)| (id.to_string(), vuln))
            .collect();
        convert_search_result_to_batch_info(make_patch(vulns, description))
    }

    // ---- urlencoding_encode -------------------------------------------------

    #[test]
    fn test_urlencoding_basic() {
        assert_eq!(urlencoding_encode("hello"), "hello");
        assert_eq!(urlencoding_encode("a b"), "a%20b");
        assert_eq!(
            urlencoding_encode("pkg:npm/lodash@4.17.21"),
            "pkg%3Anpm%2Flodash%404.17.21"
        );
    }

    #[test]
    fn test_urlencoding_unicode() {
        // Multi-byte characters are encoded byte by byte.
        assert_eq!(urlencoding_encode("café"), "caf%C3%A9");
    }

    #[test]
    fn test_urlencoding_empty() {
        assert_eq!(urlencoding_encode(""), "");
    }

    #[test]
    fn test_urlencoding_all_safe_chars() {
        let safe = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~";
        assert_eq!(urlencoding_encode(safe), safe);
    }

    #[test]
    fn test_urlencoding_slash_and_at() {
        assert_eq!(urlencoding_encode("/"), "%2F");
        assert_eq!(urlencoding_encode("@"), "%40");
    }

    // ---- is_valid_sha256_hex ------------------------------------------------

    #[test]
    fn test_is_valid_sha256_hex() {
        let valid = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
        assert!(is_valid_sha256_hex(valid));

        // Too short.
        assert!(!is_valid_sha256_hex("abcdef"));
        // Right length, wrong alphabet.
        assert!(!is_valid_sha256_hex(
            "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"
        ));
    }

    #[test]
    fn test_sha256_uppercase_valid() {
        let upper = "ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789";
        assert!(is_valid_sha256_hex(upper));
    }

    #[test]
    fn test_sha256_mixed_case_valid() {
        let mixed = "aAbBcCdDeEfF0123456789aAbBcCdDeEfF0123456789aAbBcCdDeEfF01234567";
        assert_eq!(mixed.len(), 64);
        assert!(is_valid_sha256_hex(mixed));
    }

    #[test]
    fn test_sha256_65_chars_invalid() {
        assert!(!is_valid_sha256_hex(&"a".repeat(65)));
    }

    #[test]
    fn test_sha256_63_chars_invalid() {
        assert!(!is_valid_sha256_hex(&"a".repeat(63)));
    }

    #[test]
    fn test_sha256_empty_invalid() {
        assert!(!is_valid_sha256_hex(""));
    }

    // ---- get_severity_order -------------------------------------------------

    #[test]
    fn test_severity_order() {
        assert!(get_severity_order(Some("critical")) < get_severity_order(Some("high")));
        assert!(get_severity_order(Some("high")) < get_severity_order(Some("medium")));
        assert!(get_severity_order(Some("medium")) < get_severity_order(Some("low")));
        assert!(get_severity_order(Some("low")) < get_severity_order(None));
        assert_eq!(get_severity_order(Some("unknown")), get_severity_order(None));
        // Labels are matched case-insensitively.
        assert_eq!(
            get_severity_order(Some("HIGH")),
            get_severity_order(Some("high"))
        );
    }

    // ---- get_api_client_from_env --------------------------------------------

    #[tokio::test]
    async fn test_get_api_client_from_env_no_token() {
        // Without a token the function must fall back to the public proxy.
        std::env::remove_var("SOCKET_API_TOKEN");
        let (client, is_public) = get_api_client_from_env(None).await;
        assert!(is_public);
        assert!(client.use_public_proxy);
        assert!(client.api_token().is_none());
    }

    // ---- convert_search_result_to_batch_info --------------------------------

    #[test]
    fn test_convert_search_result_to_batch_info() {
        let info = convert(
            vec![(
                "GHSA-1234-5678-9abc",
                make_vuln("Test vulnerability", "high", vec!["CVE-2024-0001"]),
            )],
            "A patch",
        );
        assert_eq!(info.uuid, "uuid-1");
        assert_eq!(info.cve_ids, vec!["CVE-2024-0001"]);
        assert_eq!(info.ghsa_ids, vec!["GHSA-1234-5678-9abc"]);
        assert_eq!(info.severity, Some("high".into()));
        assert_eq!(info.title, "Test vulnerability");
    }

    #[test]
    fn test_convert_no_vulnerabilities() {
        let info = convert(Vec::new(), "A patch description");
        assert!(info.cve_ids.is_empty());
        assert!(info.ghsa_ids.is_empty());
        assert_eq!(info.title, "A patch description");
        assert!(info.severity.is_none());
    }

    #[test]
    fn test_convert_multiple_vulns_picks_highest_severity() {
        let info = convert(
            vec![
                (
                    "GHSA-1111",
                    make_vuln("Medium vuln", "medium", vec!["CVE-2024-0001"]),
                ),
                (
                    "GHSA-2222",
                    make_vuln("Critical vuln", "critical", vec!["CVE-2024-0002"]),
                ),
            ],
            "desc",
        );
        assert_eq!(info.severity, Some("critical".into()));
    }

    #[test]
    fn test_convert_duplicate_cves_deduplicated() {
        let info = convert(
            vec![
                ("GHSA-1111", make_vuln("Vuln A", "high", vec!["CVE-2024-0001"])),
                ("GHSA-2222", make_vuln("Vuln B", "high", vec!["CVE-2024-0001"])),
            ],
            "desc",
        );
        // The shared CVE appears exactly once and nothing else sneaks in.
        assert_eq!(info.cve_ids, vec!["CVE-2024-0001"]);
    }

    #[test]
    fn test_convert_title_truncated_at_100() {
        let long_summary = "x".repeat(150);
        let info = convert(
            vec![("GHSA-1111", make_vuln(&long_summary, "high", vec![]))],
            "desc",
        );
        // 97 characters of content plus the three-character ellipsis.
        assert_eq!(info.title.len(), 100);
        assert_eq!(info.title, format!("{}...", "x".repeat(97)));
    }

    #[test]
    fn test_convert_title_unicode_truncation() {
        // 30 emoji are 120 bytes but only 30 characters: below the limit, so
        // the summary must come through untouched.
        let emoji_summary = "\u{1F600}".repeat(30);
        let info = convert(
            vec![("GHSA-1111", make_vuln(&emoji_summary, "high", vec![]))],
            "desc",
        );
        assert_eq!(info.title, emoji_summary);

        // 120 emoji exceed the limit; truncation must count characters, not
        // bytes, and must not split a character.
        let info2 = convert(Vec::new(), &"\u{1F600}".repeat(120));
        assert_eq!(info2.title.chars().count(), 100);
        assert!(info2.title.starts_with(&"\u{1F600}".repeat(97)));
        assert!(info2.title.ends_with("..."));
    }

    #[test]
    fn test_convert_title_falls_back_to_description() {
        let info = convert(
            vec![("GHSA-1111", make_vuln("", "high", vec![]))],
            "Fallback desc",
        );
        assert_eq!(info.title, "Fallback desc");
    }

    #[test]
    fn test_convert_empty_summary_and_description() {
        let info = convert(vec![("GHSA-1111", make_vuln("", "high", vec![]))], "");
        assert!(info.title.is_empty());
    }

    #[test]
    fn test_convert_cves_and_ghsas_sorted() {
        let info = convert(
            vec![
                ("GHSA-cccc", make_vuln("V1", "high", vec!["CVE-2024-0003"])),
                ("GHSA-aaaa", make_vuln("V2", "high", vec!["CVE-2024-0001"])),
                ("GHSA-bbbb", make_vuln("V3", "high", vec!["CVE-2024-0002"])),
            ],
            "desc",
        );
        assert_eq!(
            info.cve_ids,
            vec!["CVE-2024-0001", "CVE-2024-0002", "CVE-2024-0003"]
        );
        assert_eq!(info.ghsa_ids, vec!["GHSA-aaaa", "GHSA-bbbb", "GHSA-cccc"]);
    }
}