1use std::collections::HashMap;
32use std::env;
33use std::sync::Arc;
34use std::time::Duration;
35
36use chrono::{DateTime, NaiveDate, Utc};
37use reqwest::{Client, StatusCode};
38use serde::{Deserialize, Serialize};
39use tokio::time::sleep;
40
41use crate::api_clients::SimpleEmbedder;
42use crate::ruvector_native::{Domain, SemanticVector};
43use crate::{FrameworkError, Result};
44
/// Delay (ms) slept before every request when no API key is configured.
const S2_RATE_LIMIT_MS: u64 = 3_000;
/// Delay (ms) slept before every request when an API key is configured.
const S2_WITH_KEY_RATE_LIMIT_MS: u64 = 200;
/// Maximum number of retries for a single request.
const MAX_RETRIES: u32 = 3;
/// Base retry backoff (ms); the delay doubles with each further retry.
const RETRY_DELAY_MS: u64 = 2_000;
/// Default output dimension of the text embedder.
const DEFAULT_EMBEDDING_DIM: usize = 384;
51
52#[derive(Debug, Deserialize)]
58struct SearchResponse {
59 #[serde(default)]
60 total: Option<i32>,
61 #[serde(default)]
62 offset: Option<i32>,
63 #[serde(default)]
64 next: Option<i32>,
65 #[serde(default)]
66 data: Vec<PaperData>,
67}
68
69#[derive(Debug, Clone, Deserialize, Serialize)]
71struct PaperData {
72 #[serde(rename = "paperId")]
73 paper_id: String,
74
75 #[serde(default)]
76 title: Option<String>,
77
78 #[serde(rename = "abstract", default)]
79 abstract_text: Option<String>,
80
81 #[serde(default)]
82 year: Option<i32>,
83
84 #[serde(rename = "citationCount", default)]
85 citation_count: Option<i32>,
86
87 #[serde(rename = "referenceCount", default)]
88 reference_count: Option<i32>,
89
90 #[serde(rename = "influentialCitationCount", default)]
91 influential_citation_count: Option<i32>,
92
93 #[serde(default)]
94 authors: Vec<AuthorData>,
95
96 #[serde(rename = "fieldsOfStudy", default)]
97 fields_of_study: Vec<String>,
98
99 #[serde(default)]
100 venue: Option<String>,
101
102 #[serde(rename = "publicationVenue", default)]
103 publication_venue: Option<PublicationVenue>,
104
105 #[serde(default)]
106 url: Option<String>,
107
108 #[serde(rename = "openAccessPdf", default)]
109 open_access_pdf: Option<OpenAccessPdf>,
110}
111
112#[derive(Debug, Clone, Deserialize, Serialize)]
114struct AuthorData {
115 #[serde(rename = "authorId", default)]
116 author_id: Option<String>,
117
118 #[serde(default)]
119 name: Option<String>,
120}
121
122#[derive(Debug, Clone, Deserialize, Serialize)]
124struct PublicationVenue {
125 #[serde(default)]
126 name: Option<String>,
127
128 #[serde(rename = "type", default)]
129 venue_type: Option<String>,
130}
131
132#[derive(Debug, Clone, Deserialize, Serialize)]
134struct OpenAccessPdf {
135 #[serde(default)]
136 url: Option<String>,
137
138 #[serde(default)]
139 status: Option<String>,
140}
141
142#[derive(Debug, Deserialize)]
144struct CitationResponse {
145 #[serde(default)]
146 offset: Option<i32>,
147
148 #[serde(default)]
149 next: Option<i32>,
150
151 #[serde(default)]
152 data: Vec<CitationData>,
153}
154
155#[derive(Debug, Deserialize)]
157struct CitationData {
158 #[serde(rename = "citingPaper", default)]
159 citing_paper: Option<PaperData>,
160
161 #[serde(rename = "citedPaper", default)]
162 cited_paper: Option<PaperData>,
163}
164
165#[derive(Debug, Deserialize)]
167struct AuthorResponse {
168 #[serde(rename = "authorId")]
169 author_id: String,
170
171 #[serde(default)]
172 name: Option<String>,
173
174 #[serde(rename = "paperCount", default)]
175 paper_count: Option<i32>,
176
177 #[serde(rename = "citationCount", default)]
178 citation_count: Option<i32>,
179
180 #[serde(rename = "hIndex", default)]
181 h_index: Option<i32>,
182
183 #[serde(default)]
184 papers: Vec<PaperData>,
185}
186
/// Async client for the Semantic Scholar Graph API that converts paper records
/// into [`SemanticVector`]s in the [`Domain::Research`] domain.
///
/// A delay of `rate_limit_delay` is slept before every request, and failed
/// requests are retried with exponential backoff.
pub struct SemanticScholarClient {
    /// HTTP client used for every request.
    client: Client,
    /// Embedder that turns a paper's title and abstract into a vector.
    embedder: Arc<SimpleEmbedder>,
    /// Base URL of the Graph API; endpoint paths are appended to it.
    base_url: String,
    /// Optional API key, sent as the `x-api-key` header when present.
    api_key: Option<String>,
    /// Delay slept before each request (shorter when an API key is set).
    rate_limit_delay: Duration,
}
210
211impl SemanticScholarClient {
212 pub fn new(api_key: Option<String>) -> Self {
226 Self::with_embedding_dim(api_key, DEFAULT_EMBEDDING_DIM)
227 }
228
229 pub fn with_embedding_dim(api_key: Option<String>, embedding_dim: usize) -> Self {
235 let api_key = api_key.or_else(|| env::var("SEMANTIC_SCHOLAR_API_KEY").ok());
237
238 let rate_limit_delay = if api_key.is_some() {
239 Duration::from_millis(S2_WITH_KEY_RATE_LIMIT_MS)
240 } else {
241 Duration::from_millis(S2_RATE_LIMIT_MS)
242 };
243
244 Self {
245 client: Client::builder()
246 .user_agent("RuVector-Discovery/1.0")
247 .timeout(Duration::from_secs(30))
248 .build()
249 .expect("Failed to create HTTP client"),
250 embedder: Arc::new(SimpleEmbedder::new(embedding_dim)),
251 base_url: "https://api.semanticscholar.org/graph/v1".to_string(),
252 api_key,
253 rate_limit_delay,
254 }
255 }
256
257 pub async fn search_papers(&self, query: &str, limit: usize) -> Result<Vec<SemanticVector>> {
268 let limit = limit.min(100); let encoded_query = urlencoding::encode(query);
270
271 let url = format!(
272 "{}/paper/search?query={}&limit={}&fields=paperId,title,abstract,year,citationCount,referenceCount,influentialCitationCount,authors,fieldsOfStudy,venue,publicationVenue,url,openAccessPdf",
273 self.base_url, encoded_query, limit
274 );
275
276 let response: SearchResponse = self.fetch_json(&url).await?;
277
278 let mut vectors = Vec::new();
279 for paper in response.data {
280 if let Some(vector) = self.paper_to_vector(paper) {
281 vectors.push(vector);
282 }
283 }
284
285 Ok(vectors)
286 }
287
288 pub async fn get_paper(&self, paper_id: &str) -> Result<Option<SemanticVector>> {
298 let url = format!(
299 "{}/paper/{}?fields=paperId,title,abstract,year,citationCount,referenceCount,influentialCitationCount,authors,fieldsOfStudy,venue,publicationVenue,url,openAccessPdf",
300 self.base_url, paper_id
301 );
302
303 let paper: PaperData = self.fetch_json(&url).await?;
304 Ok(self.paper_to_vector(paper))
305 }
306
307 pub async fn get_citations(&self, paper_id: &str, limit: usize) -> Result<Vec<SemanticVector>> {
318 let limit = limit.min(1000); let url = format!(
321 "{}/paper/{}/citations?limit={}&fields=paperId,title,abstract,year,citationCount,referenceCount,authors,fieldsOfStudy,venue,url",
322 self.base_url, paper_id, limit
323 );
324
325 let response: CitationResponse = self.fetch_json(&url).await?;
326
327 let mut vectors = Vec::new();
328 for citation in response.data {
329 if let Some(citing_paper) = citation.citing_paper {
330 if let Some(vector) = self.paper_to_vector(citing_paper) {
331 vectors.push(vector);
332 }
333 }
334 }
335
336 Ok(vectors)
337 }
338
339 pub async fn get_references(&self, paper_id: &str, limit: usize) -> Result<Vec<SemanticVector>> {
350 let limit = limit.min(1000); let url = format!(
353 "{}/paper/{}/references?limit={}&fields=paperId,title,abstract,year,citationCount,referenceCount,authors,fieldsOfStudy,venue,url",
354 self.base_url, paper_id, limit
355 );
356
357 let response: CitationResponse = self.fetch_json(&url).await?;
358
359 let mut vectors = Vec::new();
360 for reference in response.data {
361 if let Some(cited_paper) = reference.cited_paper {
362 if let Some(vector) = self.paper_to_vector(cited_paper) {
363 vectors.push(vector);
364 }
365 }
366 }
367
368 Ok(vectors)
369 }
370
371 pub async fn search_by_field(&self, field_of_study: &str, limit: usize) -> Result<Vec<SemanticVector>> {
383 let query = format!("fieldsOfStudy:{}", field_of_study);
385 self.search_papers(&query, limit).await
386 }
387
388 pub async fn get_author(&self, author_id: &str) -> Result<Vec<SemanticVector>> {
398 let url = format!(
399 "{}/author/{}?fields=authorId,name,paperCount,citationCount,hIndex,papers.paperId,papers.title,papers.abstract,papers.year,papers.citationCount,papers.fieldsOfStudy",
400 self.base_url, author_id
401 );
402
403 let author: AuthorResponse = self.fetch_json(&url).await?;
404
405 let mut vectors = Vec::new();
406 for paper in author.papers {
407 if let Some(vector) = self.paper_to_vector(paper) {
408 vectors.push(vector);
409 }
410 }
411
412 Ok(vectors)
413 }
414
415 pub async fn search_recent(&self, query: &str, year_min: i32) -> Result<Vec<SemanticVector>> {
427 let all_results = self.search_papers(query, 100).await?;
428
429 Ok(all_results
431 .into_iter()
432 .filter(|v| {
433 v.metadata
434 .get("year")
435 .and_then(|y| y.parse::<i32>().ok())
436 .map(|year| year >= year_min)
437 .unwrap_or(false)
438 })
439 .collect())
440 }
441
442 pub async fn build_citation_graph(
460 &self,
461 paper_id: &str,
462 max_citations: usize,
463 max_references: usize,
464 ) -> Result<(Option<SemanticVector>, Vec<SemanticVector>, Vec<SemanticVector>)> {
465 let paper_result = self.get_paper(paper_id);
467 let citations_result = self.get_citations(paper_id, max_citations);
468 let references_result = self.get_references(paper_id, max_references);
469
470 let paper = paper_result.await?;
472 sleep(self.rate_limit_delay).await;
473
474 let citations = citations_result.await?;
475 sleep(self.rate_limit_delay).await;
476
477 let references = references_result.await?;
478
479 Ok((paper, citations, references))
480 }
481
482 fn paper_to_vector(&self, paper: PaperData) -> Option<SemanticVector> {
484 let title = paper.title.clone().unwrap_or_default();
485 let abstract_text = paper.abstract_text.clone().unwrap_or_default();
486
487 if title.is_empty() {
489 return None;
490 }
491
492 let combined_text = format!("{} {}", title, abstract_text);
494 let embedding = self.embedder.embed_text(&combined_text);
495
496 let timestamp = paper.year
498 .and_then(|y| NaiveDate::from_ymd_opt(y, 1, 1))
499 .map(|d| DateTime::from_naive_utc_and_offset(d.and_hms_opt(0, 0, 0).unwrap(), Utc))
500 .unwrap_or_else(Utc::now);
501
502 let mut metadata = HashMap::new();
504 metadata.insert("paper_id".to_string(), paper.paper_id.clone());
505 metadata.insert("title".to_string(), title);
506
507 if !abstract_text.is_empty() {
508 metadata.insert("abstract".to_string(), abstract_text);
509 }
510
511 if let Some(year) = paper.year {
512 metadata.insert("year".to_string(), year.to_string());
513 }
514
515 if let Some(count) = paper.citation_count {
516 metadata.insert("citationCount".to_string(), count.to_string());
517 }
518
519 if let Some(count) = paper.reference_count {
520 metadata.insert("referenceCount".to_string(), count.to_string());
521 }
522
523 if let Some(count) = paper.influential_citation_count {
524 metadata.insert("influentialCitationCount".to_string(), count.to_string());
525 }
526
527 let authors = paper
529 .authors
530 .iter()
531 .filter_map(|a| a.name.as_ref())
532 .cloned()
533 .collect::<Vec<_>>()
534 .join(", ");
535 if !authors.is_empty() {
536 metadata.insert("authors".to_string(), authors);
537 }
538
539 if !paper.fields_of_study.is_empty() {
541 metadata.insert("fieldsOfStudy".to_string(), paper.fields_of_study.join(", "));
542 }
543
544 if let Some(venue) = paper.venue.or_else(|| paper.publication_venue.and_then(|pv| pv.name)) {
546 metadata.insert("venue".to_string(), venue);
547 }
548
549 if let Some(url) = paper.url {
551 metadata.insert("url".to_string(), url);
552 } else {
553 metadata.insert(
554 "url".to_string(),
555 format!("https://www.semanticscholar.org/paper/{}", paper.paper_id),
556 );
557 }
558
559 if let Some(pdf) = paper.open_access_pdf.and_then(|p| p.url) {
561 metadata.insert("pdf_url".to_string(), pdf);
562 }
563
564 metadata.insert("source".to_string(), "semantic_scholar".to_string());
565
566 Some(SemanticVector {
567 id: format!("s2:{}", paper.paper_id),
568 embedding,
569 domain: Domain::Research,
570 timestamp,
571 metadata,
572 })
573 }
574
575 async fn fetch_json<T: for<'de> Deserialize<'de>>(&self, url: &str) -> Result<T> {
577 sleep(self.rate_limit_delay).await;
579
580 let response = self.fetch_with_retry(url).await?;
581 let json = response.json::<T>().await?;
582
583 Ok(json)
584 }
585
586 async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
588 let mut retries = 0;
589 loop {
590 let mut request = self.client.get(url);
591
592 if let Some(ref api_key) = self.api_key {
594 request = request.header("x-api-key", api_key);
595 }
596
597 match request.send().await {
598 Ok(response) => {
599 if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
600 retries += 1;
601 let delay = RETRY_DELAY_MS * (2_u64.pow(retries - 1)); tracing::warn!(
603 "Rate limited by Semantic Scholar, retrying in {}ms",
604 delay
605 );
606 sleep(Duration::from_millis(delay)).await;
607 continue;
608 }
609 if !response.status().is_success() {
610 return Err(FrameworkError::Network(
611 reqwest::Error::from(response.error_for_status().unwrap_err()),
612 ));
613 }
614 return Ok(response);
615 }
616 Err(_) if retries < MAX_RETRIES => {
617 retries += 1;
618 let delay = RETRY_DELAY_MS * (2_u64.pow(retries - 1)); tracing::warn!("Request failed, retrying ({}/{}) in {}ms", retries, MAX_RETRIES, delay);
620 sleep(Duration::from_millis(delay)).await;
621 }
622 Err(e) => return Err(FrameworkError::Network(e)),
623 }
624 }
625 }
626}
627
628impl Default for SemanticScholarClient {
629 fn default() -> Self {
630 Self::new(None)
631 }
632}
633
/// Unit tests for client construction and paper conversion, plus
/// network-dependent integration tests that are ignored by default.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = SemanticScholarClient::new(None);
        assert_eq!(client.base_url, "https://api.semanticscholar.org/graph/v1");
        // `new(None)` falls back to SEMANTIC_SCHOLAR_API_KEY, so the
        // unauthenticated settings can only be asserted when it is unset.
        if env::var("SEMANTIC_SCHOLAR_API_KEY").is_err() {
            assert!(client.api_key.is_none());
            assert_eq!(client.rate_limit_delay, Duration::from_millis(S2_RATE_LIMIT_MS));
        }
    }

    #[test]
    fn test_client_with_api_key() {
        let client = SemanticScholarClient::new(Some("test-key".to_string()));
        assert_eq!(client.api_key, Some("test-key".to_string()));
        assert_eq!(client.rate_limit_delay, Duration::from_millis(S2_WITH_KEY_RATE_LIMIT_MS));
    }

    #[test]
    fn test_custom_embedding_dim() {
        let client = SemanticScholarClient::with_embedding_dim(None, 512);
        let embedding = client.embedder.embed_text("test");
        assert_eq!(embedding.len(), 512);
    }

    #[test]
    fn test_paper_to_vector() {
        let client = SemanticScholarClient::new(None);

        let paper = PaperData {
            paper_id: "649def34f8be52c8b66281af98ae884c09aef38b".to_string(),
            title: Some("Attention Is All You Need".to_string()),
            abstract_text: Some("The dominant sequence transduction models...".to_string()),
            year: Some(2017),
            citation_count: Some(50000),
            reference_count: Some(35),
            influential_citation_count: Some(5000),
            authors: vec![
                AuthorData {
                    author_id: Some("1741101".to_string()),
                    name: Some("Ashish Vaswani".to_string()),
                },
                AuthorData {
                    author_id: Some("1699545".to_string()),
                    name: Some("Noam Shazeer".to_string()),
                },
            ],
            fields_of_study: vec!["Computer Science".to_string(), "Mathematics".to_string()],
            venue: Some("NeurIPS".to_string()),
            publication_venue: None,
            url: Some("https://arxiv.org/abs/1706.03762".to_string()),
            open_access_pdf: Some(OpenAccessPdf {
                url: Some("https://arxiv.org/pdf/1706.03762.pdf".to_string()),
                status: Some("GREEN".to_string()),
            }),
        };

        let vector = client.paper_to_vector(paper);
        assert!(vector.is_some());

        let v = vector.unwrap();
        assert_eq!(v.id, "s2:649def34f8be52c8b66281af98ae884c09aef38b");
        assert_eq!(v.domain, Domain::Research);
        assert_eq!(v.metadata.get("paper_id").unwrap(), "649def34f8be52c8b66281af98ae884c09aef38b");
        assert_eq!(v.metadata.get("title").unwrap(), "Attention Is All You Need");
        assert_eq!(v.metadata.get("year").unwrap(), "2017");
        assert_eq!(v.metadata.get("citationCount").unwrap(), "50000");
        assert_eq!(v.metadata.get("referenceCount").unwrap(), "35");
        assert_eq!(v.metadata.get("authors").unwrap(), "Ashish Vaswani, Noam Shazeer");
        assert_eq!(v.metadata.get("fieldsOfStudy").unwrap(), "Computer Science, Mathematics");
        assert_eq!(v.metadata.get("venue").unwrap(), "NeurIPS");
        assert!(v.metadata.contains_key("pdf_url"));
    }

    #[test]
    fn test_paper_to_vector_minimal() {
        let client = SemanticScholarClient::new(None);

        let paper = PaperData {
            paper_id: "test123".to_string(),
            title: Some("Minimal Paper".to_string()),
            abstract_text: None,
            year: None,
            citation_count: None,
            reference_count: None,
            influential_citation_count: None,
            authors: vec![],
            fields_of_study: vec![],
            venue: None,
            publication_venue: None,
            url: None,
            open_access_pdf: None,
        };

        let vector = client.paper_to_vector(paper);
        assert!(vector.is_some());

        let v = vector.unwrap();
        assert_eq!(v.id, "s2:test123");
        assert_eq!(v.metadata.get("title").unwrap(), "Minimal Paper");
        assert!(v.metadata.get("url").unwrap().contains("semanticscholar.org"));
    }

    #[test]
    fn test_paper_without_title() {
        let client = SemanticScholarClient::new(None);

        let paper = PaperData {
            paper_id: "test456".to_string(),
            title: None,
            abstract_text: Some("Has abstract but no title".to_string()),
            year: Some(2020),
            citation_count: None,
            reference_count: None,
            influential_citation_count: None,
            authors: vec![],
            fields_of_study: vec![],
            venue: None,
            publication_venue: None,
            url: None,
            open_access_pdf: None,
        };

        // A paper without a title cannot be embedded and must be skipped.
        let vector = client.paper_to_vector(paper);
        assert!(vector.is_none());
    }

    #[tokio::test]
    #[ignore = "requires network access to the Semantic Scholar API"]
    async fn test_search_papers_integration() {
        let client = SemanticScholarClient::new(None);
        let results = client.search_papers("machine learning", 5).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);

        if let Some(first) = vectors.first() {
            assert!(first.id.starts_with("s2:"));
            assert_eq!(first.domain, Domain::Research);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("paper_id"));
        }
    }

    #[tokio::test]
    #[ignore = "requires network access to the Semantic Scholar API"]
    async fn test_get_paper_integration() {
        let client = SemanticScholarClient::new(None);

        let result = client.get_paper("649def34f8be52c8b66281af98ae884c09aef38b").await;
        assert!(result.is_ok());

        let paper = result.unwrap();
        assert!(paper.is_some());

        let p = paper.unwrap();
        assert_eq!(p.id, "s2:649def34f8be52c8b66281af98ae884c09aef38b");
        assert!(p.metadata.get("title").unwrap().contains("Attention"));
    }

    #[tokio::test]
    #[ignore = "requires network access to the Semantic Scholar API"]
    async fn test_get_citations_integration() {
        let client = SemanticScholarClient::new(None);

        let result = client.get_citations("649def34f8be52c8b66281af98ae884c09aef38b", 10).await;
        assert!(result.is_ok());

        let citations = result.unwrap();
        assert!(citations.len() <= 10);
    }

    #[tokio::test]
    #[ignore = "requires network access to the Semantic Scholar API"]
    async fn test_search_by_field_integration() {
        let client = SemanticScholarClient::new(None);
        let results = client.search_by_field("Computer Science", 5).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);
    }

    #[tokio::test]
    #[ignore = "requires network access to the Semantic Scholar API"]
    async fn test_build_citation_graph_integration() {
        let client = SemanticScholarClient::new(None);

        let result = client
            .build_citation_graph("649def34f8be52c8b66281af98ae884c09aef38b", 5, 5)
            .await;
        assert!(result.is_ok());

        let (paper, citations, references) = result.unwrap();
        assert!(paper.is_some());
        assert!(citations.len() <= 5);
        assert!(references.len() <= 5);
    }
}
841}