1use std::collections::HashMap;
33use std::time::Duration;
34
35use chrono::{DateTime, NaiveDate, Utc};
36use reqwest::{Client, StatusCode};
37use serde::Deserialize;
38use tokio::time::sleep;
39
40use crate::api_clients::SimpleEmbedder;
41use crate::ruvector_native::{Domain, SemanticVector};
42use crate::{FrameworkError, Result};
43
/// Fixed pause inserted before every CrossRef request, in milliseconds.
const CROSSREF_RATE_LIMIT_MS: u64 = 1_000;
/// Number of times a rate-limited or failed request is retried.
const MAX_RETRIES: u32 = 3;
/// Base retry back-off in milliseconds; retry `n` waits `n * RETRY_DELAY_MS`.
const RETRY_DELAY_MS: u64 = 2_000;
/// Embedding dimension used by `CrossRefClient::new`.
const DEFAULT_EMBEDDING_DIM: usize = 384;
49
50#[derive(Debug, Deserialize)]
56struct CrossRefResponse {
57 #[serde(default)]
58 message: CrossRefMessage,
59}
60
61#[derive(Debug, Default, Deserialize)]
62struct CrossRefMessage {
63 #[serde(default)]
64 items: Vec<CrossRefWork>,
65 #[serde(rename = "total-results", default)]
66 total_results: Option<u64>,
67}
68
69#[derive(Debug, Deserialize)]
71struct CrossRefWork {
72 #[serde(rename = "DOI")]
73 doi: String,
74 #[serde(default)]
75 title: Vec<String>,
76 #[serde(rename = "abstract", default)]
77 abstract_text: Option<String>,
78 #[serde(default)]
79 author: Vec<CrossRefAuthor>,
80 #[serde(rename = "published-print", default)]
81 published_print: Option<CrossRefDate>,
82 #[serde(rename = "published-online", default)]
83 published_online: Option<CrossRefDate>,
84 #[serde(rename = "container-title", default)]
85 container_title: Vec<String>,
86 #[serde(rename = "is-referenced-by-count", default)]
87 citation_count: Option<u64>,
88 #[serde(rename = "references-count", default)]
89 references_count: Option<u64>,
90 #[serde(default)]
91 subject: Vec<String>,
92 #[serde(default)]
93 funder: Vec<CrossRefFunder>,
94 #[serde(rename = "type", default)]
95 work_type: Option<String>,
96 #[serde(default)]
97 publisher: Option<String>,
98}
99
100#[derive(Debug, Deserialize)]
101struct CrossRefAuthor {
102 #[serde(default)]
103 given: Option<String>,
104 #[serde(default)]
105 family: Option<String>,
106 #[serde(default)]
107 name: Option<String>,
108 #[serde(rename = "ORCID", default)]
109 orcid: Option<String>,
110}
111
112#[derive(Debug, Deserialize)]
113struct CrossRefDate {
114 #[serde(rename = "date-parts", default)]
115 date_parts: Vec<Vec<i32>>,
116}
117
118#[derive(Debug, Deserialize)]
119struct CrossRefFunder {
120 #[serde(default)]
121 name: Option<String>,
122 #[serde(rename = "DOI", default)]
123 doi: Option<String>,
124}
125
126pub struct CrossRefClient {
140 client: Client,
141 embedder: SimpleEmbedder,
142 base_url: String,
143 polite_email: Option<String>,
144}
145
146impl CrossRefClient {
147 pub fn new(polite_email: Option<String>) -> Self {
157 Self::with_embedding_dim(polite_email, DEFAULT_EMBEDDING_DIM)
158 }
159
160 pub fn with_embedding_dim(polite_email: Option<String>, embedding_dim: usize) -> Self {
166 let user_agent = if let Some(ref email) = polite_email {
167 format!("RuVector-Discovery/1.0 (mailto:{})", email)
168 } else {
169 "RuVector-Discovery/1.0".to_string()
170 };
171
172 Self {
173 client: Client::builder()
174 .user_agent(&user_agent)
175 .timeout(Duration::from_secs(30))
176 .build()
177 .expect("Failed to create HTTP client"),
178 embedder: SimpleEmbedder::new(embedding_dim),
179 base_url: "https://api.crossref.org".to_string(),
180 polite_email,
181 }
182 }
183
184 pub async fn search_works(&self, query: &str, limit: usize) -> Result<Vec<SemanticVector>> {
195 let encoded_query = urlencoding::encode(query);
196 let mut url = format!(
197 "{}/works?query={}&rows={}",
198 self.base_url, encoded_query, limit
199 );
200
201 if let Some(email) = &self.polite_email {
202 url.push_str(&format!("&mailto={}", email));
203 }
204
205 self.fetch_and_parse(&url).await
206 }
207
208 pub async fn get_work(&self, doi: &str) -> Result<Option<SemanticVector>> {
218 let normalized_doi = Self::normalize_doi(doi);
219 let mut url = format!("{}/works/{}", self.base_url, normalized_doi);
220
221 if let Some(email) = &self.polite_email {
222 url.push_str(&format!("?mailto={}", email));
223 }
224
225 sleep(Duration::from_millis(CROSSREF_RATE_LIMIT_MS)).await;
226
227 let response = self.fetch_with_retry(&url).await?;
228 let json_response: CrossRefResponse = response.json().await?;
229
230 if let Some(work) = json_response.message.items.into_iter().next() {
231 Ok(Some(self.work_to_vector(work)))
232 } else {
233 Ok(None)
234 }
235 }
236
237 pub async fn search_by_funder(&self, funder_id: &str, limit: usize) -> Result<Vec<SemanticVector>> {
249 let mut url = format!(
250 "{}/funders/{}/works?rows={}",
251 self.base_url, funder_id, limit
252 );
253
254 if let Some(email) = &self.polite_email {
255 url.push_str(&format!("&mailto={}", email));
256 }
257
258 self.fetch_and_parse(&url).await
259 }
260
261 pub async fn search_by_subject(&self, subject: &str, limit: usize) -> Result<Vec<SemanticVector>> {
272 let encoded_subject = urlencoding::encode(subject);
273 let mut url = format!(
274 "{}/works?filter=has-subject:true&query.subject={}&rows={}",
275 self.base_url, encoded_subject, limit
276 );
277
278 if let Some(email) = &self.polite_email {
279 url.push_str(&format!("&mailto={}", email));
280 }
281
282 self.fetch_and_parse(&url).await
283 }
284
285 pub async fn get_citations(&self, doi: &str, limit: usize) -> Result<Vec<SemanticVector>> {
296 let normalized_doi = Self::normalize_doi(doi);
297 let mut url = format!(
298 "{}/works?filter=references:{}&rows={}",
299 self.base_url, normalized_doi, limit
300 );
301
302 if let Some(email) = &self.polite_email {
303 url.push_str(&format!("&mailto={}", email));
304 }
305
306 self.fetch_and_parse(&url).await
307 }
308
309 pub async fn search_recent(&self, query: &str, from_date: &str, limit: usize) -> Result<Vec<SemanticVector>> {
321 let encoded_query = urlencoding::encode(query);
322 let mut url = format!(
323 "{}/works?query={}&filter=from-pub-date:{}&rows={}",
324 self.base_url, encoded_query, from_date, limit
325 );
326
327 if let Some(email) = &self.polite_email {
328 url.push_str(&format!("&mailto={}", email));
329 }
330
331 self.fetch_and_parse(&url).await
332 }
333
334 pub async fn search_by_type(
355 &self,
356 work_type: &str,
357 query: Option<&str>,
358 limit: usize,
359 ) -> Result<Vec<SemanticVector>> {
360 let mut url = format!(
361 "{}/works?filter=type:{}&rows={}",
362 self.base_url, work_type, limit
363 );
364
365 if let Some(q) = query {
366 let encoded_query = urlencoding::encode(q);
367 url.push_str(&format!("&query={}", encoded_query));
368 }
369
370 if let Some(email) = &self.polite_email {
371 url.push_str(&format!("&mailto={}", email));
372 }
373
374 self.fetch_and_parse(&url).await
375 }
376
377 async fn fetch_and_parse(&self, url: &str) -> Result<Vec<SemanticVector>> {
379 sleep(Duration::from_millis(CROSSREF_RATE_LIMIT_MS)).await;
381
382 let response = self.fetch_with_retry(url).await?;
383 let crossref_response: CrossRefResponse = response.json().await?;
384
385 let vectors = crossref_response
387 .message
388 .items
389 .into_iter()
390 .map(|work| self.work_to_vector(work))
391 .collect();
392
393 Ok(vectors)
394 }
395
396 fn work_to_vector(&self, work: CrossRefWork) -> SemanticVector {
398 let title = work
400 .title
401 .first()
402 .cloned()
403 .unwrap_or_else(|| "Untitled".to_string());
404
405 let abstract_text = work.abstract_text.unwrap_or_default();
407
408 let timestamp = work
410 .published_print
411 .or(work.published_online)
412 .and_then(|date| Self::parse_crossref_date(&date))
413 .unwrap_or_else(Utc::now);
414
415 let combined_text = if abstract_text.is_empty() {
417 title.clone()
418 } else {
419 format!("{} {}", title, abstract_text)
420 };
421 let embedding = self.embedder.embed_text(&combined_text);
422
423 let authors = work
425 .author
426 .iter()
427 .map(|a| Self::format_author_name(a))
428 .collect::<Vec<_>>()
429 .join("; ");
430
431 let journal = work
433 .container_title
434 .first()
435 .cloned()
436 .unwrap_or_default();
437
438 let subjects = work.subject.join(", ");
440
441 let funders = work
443 .funder
444 .iter()
445 .filter_map(|f| f.name.clone())
446 .collect::<Vec<_>>()
447 .join(", ");
448
449 let mut metadata = HashMap::new();
451 metadata.insert("doi".to_string(), work.doi.clone());
452 metadata.insert("title".to_string(), title);
453 metadata.insert("abstract".to_string(), abstract_text);
454 metadata.insert("authors".to_string(), authors);
455 metadata.insert("journal".to_string(), journal);
456 metadata.insert("subjects".to_string(), subjects);
457 metadata.insert(
458 "citation_count".to_string(),
459 work.citation_count.unwrap_or(0).to_string(),
460 );
461 metadata.insert(
462 "references_count".to_string(),
463 work.references_count.unwrap_or(0).to_string(),
464 );
465 metadata.insert("funders".to_string(), funders);
466 metadata.insert(
467 "type".to_string(),
468 work.work_type.unwrap_or_else(|| "unknown".to_string()),
469 );
470 if let Some(publisher) = work.publisher {
471 metadata.insert("publisher".to_string(), publisher);
472 }
473 metadata.insert("source".to_string(), "crossref".to_string());
474
475 SemanticVector {
476 id: format!("doi:{}", work.doi),
477 embedding,
478 domain: Domain::Research,
479 timestamp,
480 metadata,
481 }
482 }
483
484 fn parse_crossref_date(date: &CrossRefDate) -> Option<DateTime<Utc>> {
486 if let Some(parts) = date.date_parts.first() {
487 if parts.is_empty() {
488 return None;
489 }
490
491 let year = parts[0];
492 let month = parts.get(1).copied().unwrap_or(1).max(1).min(12);
493 let day = parts.get(2).copied().unwrap_or(1).max(1).min(31);
494
495 NaiveDate::from_ymd_opt(year, month as u32, day as u32)
496 .and_then(|d| d.and_hms_opt(0, 0, 0))
497 .map(|dt| dt.and_utc())
498 } else {
499 None
500 }
501 }
502
503 fn format_author_name(author: &CrossRefAuthor) -> String {
505 if let Some(name) = &author.name {
506 name.clone()
507 } else {
508 let given = author.given.as_deref().unwrap_or("");
509 let family = author.family.as_deref().unwrap_or("");
510 format!("{} {}", given, family).trim().to_string()
511 }
512 }
513
514 fn normalize_doi(doi: &str) -> String {
516 doi.trim()
517 .trim_start_matches("http://")
518 .trim_start_matches("https://")
519 .trim_start_matches("doi.org/")
520 .trim_start_matches("dx.doi.org/")
521 .to_string()
522 }
523
524 async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
526 let mut retries = 0;
527 loop {
528 match self.client.get(url).send().await {
529 Ok(response) => {
530 if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
531 {
532 retries += 1;
533 tracing::warn!(
534 "Rate limited by CrossRef, retrying in {}ms",
535 RETRY_DELAY_MS * retries as u64
536 );
537 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
538 continue;
539 }
540 if !response.status().is_success() {
541 return Err(FrameworkError::Network(
542 reqwest::Error::from(response.error_for_status().unwrap_err()),
543 ));
544 }
545 return Ok(response);
546 }
547 Err(_) if retries < MAX_RETRIES => {
548 retries += 1;
549 tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
550 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
551 }
552 Err(e) => return Err(FrameworkError::Network(e)),
553 }
554 }
555 }
556}
557
558impl Default for CrossRefClient {
559 fn default() -> Self {
560 Self::new(None)
561 }
562}
563
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a work with only a DOI set, so tests can exercise the fallbacks in
    /// `work_to_vector` and override just the fields they care about.
    fn bare_work(doi: &str) -> CrossRefWork {
        CrossRefWork {
            doi: doi.to_string(),
            title: Vec::new(),
            abstract_text: None,
            author: Vec::new(),
            published_print: None,
            published_online: None,
            container_title: Vec::new(),
            citation_count: None,
            references_count: None,
            subject: Vec::new(),
            funder: Vec::new(),
            work_type: None,
            publisher: None,
        }
    }

    #[test]
    fn test_crossref_client_creation() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));
        assert_eq!(client.base_url, "https://api.crossref.org");
        assert_eq!(client.polite_email, Some("test@example.com".to_string()));
    }

    #[test]
    fn test_crossref_client_without_email() {
        let client = CrossRefClient::new(None);
        assert_eq!(client.base_url, "https://api.crossref.org");
        assert_eq!(client.polite_email, None);
    }

    #[test]
    fn test_custom_embedding_dim() {
        let client = CrossRefClient::with_embedding_dim(None, 512);
        let embedding = client.embedder.embed_text("test");
        assert_eq!(embedding.len(), 512);
    }

    #[test]
    fn test_normalize_doi() {
        assert_eq!(
            CrossRefClient::normalize_doi("10.1038/nature12373"),
            "10.1038/nature12373"
        );
        assert_eq!(
            CrossRefClient::normalize_doi("http://doi.org/10.1038/nature12373"),
            "10.1038/nature12373"
        );
        assert_eq!(
            CrossRefClient::normalize_doi("https://dx.doi.org/10.1038/nature12373"),
            "10.1038/nature12373"
        );
        assert_eq!(
            CrossRefClient::normalize_doi(" 10.1038/nature12373 "),
            "10.1038/nature12373"
        );
    }

    #[test]
    fn test_parse_crossref_date() {
        // Parses a single `[year, month?, day?]` array and renders it as YYYY-MM-DD.
        let ymd = |parts: Vec<i32>| {
            CrossRefClient::parse_crossref_date(&CrossRefDate {
                date_parts: vec![parts],
            })
            .map(|dt| dt.format("%Y-%m-%d").to_string())
        };

        assert_eq!(ymd(vec![2024, 3, 15]).as_deref(), Some("2024-03-15"));
        // A missing day or month defaults to 1.
        assert_eq!(ymd(vec![2024, 3]).as_deref(), Some("2024-03-01"));
        assert_eq!(ymd(vec![2024]).as_deref(), Some("2024-01-01"));
        // Out-of-range months are clamped rather than rejected.
        assert_eq!(ymd(vec![2024, 13, 1]).as_deref(), Some("2024-12-01"));
        // An empty parts array yields no date...
        assert_eq!(ymd(vec![]), None);
        // ...as does having no parts arrays at all.
        let no_parts = CrossRefDate { date_parts: vec![] };
        assert!(CrossRefClient::parse_crossref_date(&no_parts).is_none());
    }

    #[test]
    fn test_format_author_name() {
        let author1 = CrossRefAuthor {
            given: Some("John".to_string()),
            family: Some("Doe".to_string()),
            name: None,
            orcid: None,
        };
        assert_eq!(CrossRefClient::format_author_name(&author1), "John Doe");

        let author2 = CrossRefAuthor {
            given: None,
            family: None,
            name: Some("Jane Smith".to_string()),
            orcid: None,
        };
        assert_eq!(CrossRefClient::format_author_name(&author2), "Jane Smith");

        let author3 = CrossRefAuthor {
            given: None,
            family: Some("Einstein".to_string()),
            name: None,
            orcid: None,
        };
        assert_eq!(CrossRefClient::format_author_name(&author3), "Einstein");

        let given_only = CrossRefAuthor {
            given: Some("John".to_string()),
            family: None,
            name: None,
            orcid: None,
        };
        assert_eq!(CrossRefClient::format_author_name(&given_only), "John");

        // `name` takes precedence over given/family when both are present.
        let both = CrossRefAuthor {
            given: Some("Ada".to_string()),
            family: Some("Lovelace".to_string()),
            name: Some("Analytical Engine Consortium".to_string()),
            orcid: None,
        };
        assert_eq!(
            CrossRefClient::format_author_name(&both),
            "Analytical Engine Consortium"
        );
    }

    #[test]
    fn test_work_to_vector() {
        let client = CrossRefClient::new(None);

        let work = CrossRefWork {
            doi: "10.1234/example.2024".to_string(),
            title: vec!["Deep Learning for Climate Science".to_string()],
            abstract_text: Some("We propose a novel approach to climate modeling...".to_string()),
            author: vec![
                CrossRefAuthor {
                    given: Some("Alice".to_string()),
                    family: Some("Johnson".to_string()),
                    name: None,
                    orcid: Some("0000-0001-2345-6789".to_string()),
                },
                CrossRefAuthor {
                    given: Some("Bob".to_string()),
                    family: Some("Smith".to_string()),
                    name: None,
                    orcid: None,
                },
            ],
            published_print: Some(CrossRefDate {
                date_parts: vec![vec![2024, 6, 15]],
            }),
            published_online: None,
            container_title: vec!["Nature Climate Change".to_string()],
            citation_count: Some(42),
            references_count: Some(35),
            subject: vec!["Climate Science".to_string(), "Machine Learning".to_string()],
            funder: vec![CrossRefFunder {
                name: Some("National Science Foundation".to_string()),
                doi: Some("10.13039/100000001".to_string()),
            }],
            work_type: Some("journal-article".to_string()),
            publisher: Some("Nature Publishing Group".to_string()),
        };

        let vector = client.work_to_vector(work);

        assert_eq!(vector.id, "doi:10.1234/example.2024");
        assert_eq!(vector.domain, Domain::Research);
        assert_eq!(
            vector.timestamp.format("%Y-%m-%d").to_string(),
            "2024-06-15"
        );
        assert_eq!(vector.metadata.get("doi").unwrap(), "10.1234/example.2024");
        assert_eq!(
            vector.metadata.get("title").unwrap(),
            "Deep Learning for Climate Science"
        );
        assert_eq!(
            vector.metadata.get("abstract").unwrap(),
            "We propose a novel approach to climate modeling..."
        );
        assert_eq!(
            vector.metadata.get("authors").unwrap(),
            "Alice Johnson; Bob Smith"
        );
        assert_eq!(
            vector.metadata.get("journal").unwrap(),
            "Nature Climate Change"
        );
        assert_eq!(vector.metadata.get("citation_count").unwrap(), "42");
        assert_eq!(vector.metadata.get("references_count").unwrap(), "35");
        assert_eq!(
            vector.metadata.get("subjects").unwrap(),
            "Climate Science, Machine Learning"
        );
        assert_eq!(
            vector.metadata.get("funders").unwrap(),
            "National Science Foundation"
        );
        assert_eq!(vector.metadata.get("type").unwrap(), "journal-article");
        assert_eq!(
            vector.metadata.get("publisher").unwrap(),
            "Nature Publishing Group"
        );
        assert_eq!(vector.metadata.get("source").unwrap(), "crossref");
        assert_eq!(vector.embedding.len(), DEFAULT_EMBEDDING_DIM);
    }

    #[test]
    fn test_work_to_vector_defaults() {
        let client = CrossRefClient::new(None);
        let vector = client.work_to_vector(bare_work("10.1000/xyz"));

        assert_eq!(vector.id, "doi:10.1000/xyz");
        assert_eq!(vector.metadata.get("title").unwrap(), "Untitled");
        assert_eq!(vector.metadata.get("abstract").unwrap(), "");
        assert_eq!(vector.metadata.get("authors").unwrap(), "");
        assert_eq!(vector.metadata.get("journal").unwrap(), "");
        assert_eq!(vector.metadata.get("subjects").unwrap(), "");
        assert_eq!(vector.metadata.get("funders").unwrap(), "");
        assert_eq!(vector.metadata.get("citation_count").unwrap(), "0");
        assert_eq!(vector.metadata.get("references_count").unwrap(), "0");
        assert_eq!(vector.metadata.get("type").unwrap(), "unknown");
        assert!(!vector.metadata.contains_key("publisher"));
        assert_eq!(vector.embedding.len(), DEFAULT_EMBEDDING_DIM);
    }

    #[test]
    fn test_work_to_vector_uses_online_date_without_print_date() {
        let client = CrossRefClient::new(None);
        let mut work = bare_work("10.1000/online");
        work.published_online = Some(CrossRefDate {
            date_parts: vec![vec![2023, 1, 2]],
        });

        let vector = client.work_to_vector(work);
        assert_eq!(
            vector.timestamp.format("%Y-%m-%d").to_string(),
            "2023-01-02"
        );
    }

    // The tests below hit the live CrossRef API; run with `cargo test -- --ignored`.
    // `expect` (instead of `assert!(is_ok())`) surfaces the actual error on failure.

    #[tokio::test]
    #[ignore]
    async fn test_search_works_integration() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));
        let vectors = client
            .search_works("machine learning", 5)
            .await
            .expect("search_works request failed");

        assert!(vectors.len() <= 5);

        if let Some(first) = vectors.first() {
            assert!(first.id.starts_with("doi:"));
            assert_eq!(first.domain, Domain::Research);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("doi"));
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_get_work_integration() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));

        let vector = client
            .get_work("10.1038/s41586-021-03819-2")
            .await
            .expect("get_work request failed")
            .expect("known DOI should resolve to a work");

        assert_eq!(vector.id, "doi:10.1038/s41586-021-03819-2");
        assert_eq!(vector.domain, Domain::Research);
    }

    #[tokio::test]
    #[ignore]
    async fn test_search_by_funder_integration() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));

        let vectors = client
            .search_by_funder("10.13039/100000001", 3)
            .await
            .expect("search_by_funder request failed");

        assert!(vectors.len() <= 3);
    }

    #[tokio::test]
    #[ignore]
    async fn test_search_by_type_integration() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));

        let vectors = client
            .search_by_type("dataset", Some("climate"), 5)
            .await
            .expect("search_by_type request failed");

        assert!(vectors.len() <= 5);
    }

    #[tokio::test]
    #[ignore]
    async fn test_search_recent_integration() {
        let client = CrossRefClient::new(Some("test@example.com".to_string()));

        let vectors = client
            .search_recent("quantum computing", "2024-01-01", 5)
            .await
            .expect("search_recent request failed");

        assert!(vectors.len() <= 5);
    }
}
836}