ruvector_data_framework/arxiv_client.rs

//! ArXiv Preprint API Integration
//!
//! This module provides an async client for fetching academic preprints from ArXiv.org,
//! converting responses to SemanticVector format for RuVector discovery.
//!
//! # ArXiv API Details
//! - Base URL: https://export.arxiv.org/api/query
//! - Free access, no authentication required
//! - Returns Atom XML feed
//! - Rate limit: 1 request per 3 seconds (enforced by client)
//!
//! # Example
//! ```rust,ignore
//! use ruvector_data_framework::arxiv_client::ArxivClient;
//!
//! let client = ArxivClient::new();
//!
//! // Search papers by keywords
//! let vectors = client.search("machine learning", 10).await?;
//!
//! // Search by category
//! let ai_papers = client.search_category("cs.AI", 20).await?;
//!
//! // Get recent papers in a category
//! let recent = client.search_recent("cs.LG", 7).await?;
//! ```

use std::collections::HashMap;
use std::time::Duration;

use chrono::{DateTime, NaiveDateTime, Utc};
use reqwest::{Client, StatusCode};
use serde::Deserialize;
use tokio::time::sleep;

use crate::api_clients::SimpleEmbedder;
use crate::ruvector_native::{Domain, SemanticVector};
use crate::{FrameworkError, Result};

/// Rate limiting configuration for ArXiv API
const ARXIV_RATE_LIMIT_MS: u64 = 3000; // 3 seconds between requests
const MAX_RETRIES: u32 = 3;
const RETRY_DELAY_MS: u64 = 2000;
const DEFAULT_EMBEDDING_DIM: usize = 384;

// ============================================================================
// ArXiv Atom Feed Structures
// ============================================================================

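// For orientation, the structs below map onto an Atom entry shaped roughly
// like this (abridged; namespace and rel attributes omitted):
//
// <feed xmlns="http://www.w3.org/2005/Atom">
//   <entry>
//     <id>http://arxiv.org/abs/2401.12345v1</id>
//     <title>Paper Title</title>
//     <summary>Abstract text...</summary>
//     <published>2024-01-15T12:00:00Z</published>
//     <author><name>Jane Smith</name></author>
//     <category term="cs.LG"/>
//     <link href="http://arxiv.org/pdf/2401.12345v1" title="pdf" type="application/pdf"/>
//   </entry>
// </feed>
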
/// ArXiv API Atom feed response
#[derive(Debug, Deserialize)]
struct ArxivFeed {
    #[serde(rename = "entry", default)]
    entries: Vec<ArxivEntry>,
    #[serde(rename = "totalResults", default)]
    total_results: Option<TotalResults>,
}

#[derive(Debug, Deserialize)]
struct TotalResults {
    #[serde(rename = "$value", default)]
    value: Option<String>,
}

/// ArXiv entry (paper)
#[derive(Debug, Deserialize)]
struct ArxivEntry {
    #[serde(rename = "id")]
    id: String,
    #[serde(rename = "title")]
    title: String,
    #[serde(rename = "summary")]
    summary: String,
    #[serde(rename = "published")]
    published: String,
    #[serde(rename = "updated", default)]
    updated: Option<String>,
    #[serde(rename = "author", default)]
    authors: Vec<ArxivAuthor>,
    #[serde(rename = "category", default)]
    categories: Vec<ArxivCategory>,
    #[serde(rename = "link", default)]
    links: Vec<ArxivLink>,
}

#[derive(Debug, Deserialize)]
struct ArxivAuthor {
    #[serde(rename = "name")]
    name: String,
}

#[derive(Debug, Deserialize)]
struct ArxivCategory {
    #[serde(rename = "@term")]
    term: String,
}

#[derive(Debug, Deserialize)]
struct ArxivLink {
    #[serde(rename = "@href")]
    href: String,
    #[serde(rename = "@type", default)]
    link_type: Option<String>,
    #[serde(rename = "@title", default)]
    title: Option<String>,
}

// ============================================================================
// ArXiv Client
// ============================================================================

/// Client for ArXiv.org preprint API
///
/// Provides methods to search for academic papers, filter by category,
/// and convert results to SemanticVector format for RuVector analysis.
///
/// # Rate Limiting
/// The client automatically enforces ArXiv's rate limit of 1 request per 3 seconds.
/// Includes retry logic for transient failures.
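///
/// Every public call sleeps for the rate-limit interval before its request is
/// sent, so large batches are intentionally slow.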
pub struct ArxivClient {
    client: Client,
    embedder: SimpleEmbedder,
    base_url: String,
}

impl ArxivClient {
    /// Create a new ArXiv API client
    ///
    /// # Example
    /// ```rust,ignore
    /// let client = ArxivClient::new();
    /// ```
    pub fn new() -> Self {
        Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
    }

    /// Create a new ArXiv API client with custom embedding dimension
    ///
    /// # Arguments
    /// * `embedding_dim` - Dimension for text embeddings (default: 384)
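    ///
    /// # Example
    /// ```rust,ignore
    /// // Use 512-dimensional embeddings instead of the 384-dim default
    /// let client = ArxivClient::with_embedding_dim(512);
    /// ```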
    pub fn with_embedding_dim(embedding_dim: usize) -> Self {
        Self {
            client: Client::builder()
                .user_agent("RuVector-Discovery/1.0")
                .timeout(Duration::from_secs(30))
                .build()
                .expect("Failed to create HTTP client"),
            embedder: SimpleEmbedder::new(embedding_dim),
            base_url: "https://export.arxiv.org/api/query".to_string(),
        }
    }

    /// Search papers by keywords
    ///
    /// # Arguments
    /// * `query` - Search query (keywords, title, author, etc.)
    /// * `max_results` - Maximum number of results to return
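    ///
    /// The query is wrapped in ArXiv's `all:` field, which matches against
    /// title, abstract, authors, and the other indexed fields.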
    ///
    /// # Example
    /// ```rust,ignore
    /// let vectors = client.search("quantum computing", 50).await?;
    /// ```
    pub async fn search(&self, query: &str, max_results: usize) -> Result<Vec<SemanticVector>> {
        let encoded_query = urlencoding::encode(query);
        let url = format!(
            "{}?search_query=all:{}&start=0&max_results={}",
            self.base_url, encoded_query, max_results
        );

        self.fetch_and_parse(&url).await
    }

    /// Search papers by ArXiv category
    ///
    /// # Arguments
    /// * `category` - ArXiv category code (e.g., "cs.AI", "physics.ao-ph", "q-fin.ST")
    /// * `max_results` - Maximum number of results to return
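    ///
    /// Results are sorted by submission date, newest first.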
    ///
    /// # Supported Categories
    /// - `cs.AI` - Artificial Intelligence
    /// - `cs.LG` - Machine Learning
    /// - `cs.CL` - Computation and Language
    /// - `stat.ML` - Statistics - Machine Learning
    /// - `q-fin.*` - Quantitative Finance (ST, PM, TR, etc.)
    /// - `physics.ao-ph` - Atmospheric and Oceanic Physics
    /// - `econ.*` - Economics
    ///
    /// # Example
    /// ```rust,ignore
    /// let ai_papers = client.search_category("cs.AI", 100).await?;
    /// let climate_papers = client.search_category("physics.ao-ph", 50).await?;
    /// ```
    pub async fn search_category(
        &self,
        category: &str,
        max_results: usize,
    ) -> Result<Vec<SemanticVector>> {
        let url = format!(
            "{}?search_query=cat:{}&start=0&max_results={}&sortBy=submittedDate&sortOrder=descending",
            self.base_url, category, max_results
        );

        self.fetch_and_parse(&url).await
    }

    /// Get a single paper by ArXiv ID
    ///
    /// # Arguments
    /// * `arxiv_id` - ArXiv paper ID (e.g., "2401.12345" or "arXiv:2401.12345")
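    ///
    /// Uses ArXiv's `id_list` query parameter, which also accepts a
    /// comma-separated list of IDs; this helper requests a single one.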
    ///
    /// # Example
    /// ```rust,ignore
    /// let paper = client.get_paper("2401.12345").await?;
    /// ```
    pub async fn get_paper(&self, arxiv_id: &str) -> Result<Option<SemanticVector>> {
        // Strip "arXiv:" prefix if present
        let id = arxiv_id.trim_start_matches("arXiv:");

        let url = format!("{}?id_list={}", self.base_url, id);
        let mut results = self.fetch_and_parse(&url).await?;

        Ok(results.pop())
    }

    /// Search recent papers in a category within the last N days
    ///
    /// # Arguments
    /// * `category` - ArXiv category code
    /// * `days` - Number of days to look back (e.g., 7 for one week)
    ///
    /// # Example
    /// ```rust,ignore
    /// // Get ML papers from the last 3 days
    /// let recent = client.search_recent("cs.LG", 3).await?;
    /// ```
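    ///
    /// Note: only the 100 most recently submitted papers in the category are
    /// fetched before the date filter is applied, so long look-back windows
    /// may be truncated.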
    pub async fn search_recent(
        &self,
        category: &str,
        days: u64,
    ) -> Result<Vec<SemanticVector>> {
        let cutoff_date = Utc::now() - chrono::Duration::days(days as i64);

        let url = format!(
            "{}?search_query=cat:{}&start=0&max_results=100&sortBy=submittedDate&sortOrder=descending",
            self.base_url, category
        );

        let all_results = self.fetch_and_parse(&url).await?;

        // Filter by date
        Ok(all_results
            .into_iter()
            .filter(|v| v.timestamp >= cutoff_date)
            .collect())
    }

    /// Search papers across multiple categories
    ///
    /// # Arguments
    /// * `categories` - List of ArXiv category codes
    /// * `max_results_per_category` - Maximum results per category
    ///
    /// # Example
    /// ```rust,ignore
    /// let categories = vec!["cs.AI", "cs.LG", "stat.ML"];
    /// let papers = client.search_multiple_categories(&categories, 20).await?;
    /// ```
    pub async fn search_multiple_categories(
        &self,
        categories: &[&str],
        max_results_per_category: usize,
    ) -> Result<Vec<SemanticVector>> {
        let mut all_vectors = Vec::new();

        for category in categories {
            match self.search_category(category, max_results_per_category).await {
                Ok(mut vectors) => {
                    all_vectors.append(&mut vectors);
                }
                Err(e) => {
                    tracing::warn!("Failed to fetch category {}: {}", category, e);
                }
            }
            // Extra safety delay between categories; fetch_and_parse also
            // rate-limits each individual request.
            sleep(Duration::from_millis(ARXIV_RATE_LIMIT_MS)).await;
        }

        Ok(all_vectors)
    }

    /// Fetch and parse ArXiv Atom feed
    async fn fetch_and_parse(&self, url: &str) -> Result<Vec<SemanticVector>> {
        // Rate limiting
        sleep(Duration::from_millis(ARXIV_RATE_LIMIT_MS)).await;

        let response = self.fetch_with_retry(url).await?;
        let xml = response.text().await?;

        // Parse XML feed
        let feed: ArxivFeed = quick_xml::de::from_str(&xml).map_err(|e| {
            FrameworkError::Ingestion(format!("Failed to parse ArXiv XML: {}", e))
        })?;

        // Convert entries to SemanticVectors
        let mut vectors = Vec::new();
        for entry in feed.entries {
            if let Some(vector) = self.entry_to_vector(entry) {
                vectors.push(vector);
            }
        }

        Ok(vectors)
    }

    /// Convert ArXiv entry to SemanticVector
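    ///
    /// Populates the metadata keys `arxiv_id`, `title`, `abstract`, `authors`,
    /// `categories`, `pdf_url`, and `source`. Returns `None` when the
    /// published date cannot be parsed.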
    fn entry_to_vector(&self, entry: ArxivEntry) -> Option<SemanticVector> {
        // Extract ArXiv ID from full URL
        let arxiv_id = entry
            .id
            .split('/')
            .last()
            .unwrap_or(&entry.id)
            .to_string();

        // Clean up title and abstract
        let title = entry.title.trim().replace('\n', " ");
        let abstract_text = entry.summary.trim().replace('\n', " ");

        // Parse publication date
        let timestamp = Self::parse_arxiv_date(&entry.published)?;

        // Generate embedding from title + abstract
        let combined_text = format!("{} {}", title, abstract_text);
        let embedding = self.embedder.embed_text(&combined_text);

        // Extract authors
        let authors = entry
            .authors
            .iter()
            .map(|a| a.name.clone())
            .collect::<Vec<_>>()
            .join(", ");

        // Extract categories
        let categories = entry
            .categories
            .iter()
            .map(|c| c.term.clone())
            .collect::<Vec<_>>()
            .join(", ");

        // Find PDF URL
        let pdf_url = entry
            .links
            .iter()
            .find(|l| l.title.as_deref() == Some("pdf"))
            .map(|l| l.href.clone())
            .unwrap_or_else(|| format!("https://arxiv.org/pdf/{}.pdf", arxiv_id));

        // Build metadata
        let mut metadata = HashMap::new();
        metadata.insert("arxiv_id".to_string(), arxiv_id.clone());
        metadata.insert("title".to_string(), title);
        metadata.insert("abstract".to_string(), abstract_text);
        metadata.insert("authors".to_string(), authors);
        metadata.insert("categories".to_string(), categories);
        metadata.insert("pdf_url".to_string(), pdf_url);
        metadata.insert("source".to_string(), "arxiv".to_string());

        Some(SemanticVector {
            id: format!("arXiv:{}", arxiv_id),
            embedding,
            domain: Domain::Research,
            timestamp,
            metadata,
        })
    }

    /// Parse ArXiv date format (ISO 8601)
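    ///
    /// # Example
    /// ```rust,ignore
    /// let ts = ArxivClient::parse_arxiv_date("2024-01-15T12:30:00Z");
    /// assert!(ts.is_some());
    /// ```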
    fn parse_arxiv_date(date_str: &str) -> Option<DateTime<Utc>> {
        // ArXiv uses ISO 8601 format: 2024-01-15T12:30:00Z
        DateTime::parse_from_rfc3339(date_str)
            .ok()
            .map(|dt| dt.with_timezone(&Utc))
            .or_else(|| {
                // Fallback: try parsing without timezone
                NaiveDateTime::parse_from_str(date_str, "%Y-%m-%dT%H:%M:%S")
                    .ok()
                    .map(|ndt| DateTime::from_naive_utc_and_offset(ndt, Utc))
            })
    }

    /// Fetch with retry logic
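    ///
    /// Retries up to MAX_RETRIES times with a linearly increasing delay
    /// (RETRY_DELAY_MS * attempt number), for both transport errors and
    /// HTTP 429 responses.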
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
                    {
                        retries += 1;
                        tracing::warn!(
                            "Rate limited by ArXiv, retrying in {}ms",
                            RETRY_DELAY_MS * retries as u64
                        );
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    if !response.status().is_success() {
                        return Err(FrameworkError::Network(
                            response.error_for_status().unwrap_err(),
                        ));
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}

impl Default for ArxivClient {
    fn default() -> Self {
        Self::new()
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arxiv_client_creation() {
        let client = ArxivClient::new();
        assert_eq!(client.base_url, "https://export.arxiv.org/api/query");
    }

    #[test]
    fn test_custom_embedding_dim() {
        let client = ArxivClient::with_embedding_dim(512);
        let embedding = client.embedder.embed_text("test");
        assert_eq!(embedding.len(), 512);
    }

    #[test]
    fn test_parse_arxiv_date() {
        // Standard ISO 8601
        let date1 = ArxivClient::parse_arxiv_date("2024-01-15T12:30:00Z");
        assert!(date1.is_some());

        // Without Z suffix
        let date2 = ArxivClient::parse_arxiv_date("2024-01-15T12:30:00");
        assert!(date2.is_some());
    }

    #[test]
    fn test_entry_to_vector() {
        let client = ArxivClient::new();

        let entry = ArxivEntry {
            id: "http://arxiv.org/abs/2401.12345v1".to_string(),
            title: "Deep Learning for Climate Science".to_string(),
            summary: "We propose a novel approach...".to_string(),
            published: "2024-01-15T12:00:00Z".to_string(),
            updated: None,
            authors: vec![
                ArxivAuthor {
                    name: "John Doe".to_string(),
                },
                ArxivAuthor {
                    name: "Jane Smith".to_string(),
                },
            ],
            categories: vec![
                ArxivCategory {
                    term: "cs.LG".to_string(),
                },
                ArxivCategory {
                    term: "physics.ao-ph".to_string(),
                },
            ],
            links: vec![],
        };

        let vector = client.entry_to_vector(entry);
        assert!(vector.is_some());

        let v = vector.unwrap();
        assert_eq!(v.id, "arXiv:2401.12345v1");
        assert_eq!(v.domain, Domain::Research);
        assert_eq!(v.metadata.get("arxiv_id").unwrap(), "2401.12345v1");
        assert_eq!(
            v.metadata.get("title").unwrap(),
            "Deep Learning for Climate Science"
        );
        assert_eq!(v.metadata.get("authors").unwrap(), "John Doe, Jane Smith");
        assert_eq!(v.metadata.get("categories").unwrap(), "cs.LG, physics.ao-ph");
    }

    #[tokio::test]
    #[ignore] // Ignored by default to avoid hitting the ArXiv API in tests
    async fn test_search_integration() {
        let client = ArxivClient::new();
        let results = client.search("machine learning", 5).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 5);

        if !vectors.is_empty() {
            let first = &vectors[0];
            assert!(first.id.starts_with("arXiv:"));
            assert_eq!(first.domain, Domain::Research);
            assert!(first.metadata.contains_key("title"));
            assert!(first.metadata.contains_key("abstract"));
        }
    }

    #[tokio::test]
    #[ignore] // Ignored by default to avoid hitting the ArXiv API in tests
    async fn test_search_category_integration() {
        let client = ArxivClient::new();
        let results = client.search_category("cs.AI", 3).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 3);
    }

    #[tokio::test]
    #[ignore] // Ignored by default to avoid hitting the ArXiv API in tests
    async fn test_get_paper_integration() {
        let client = ArxivClient::new();

        // Try to fetch a known paper (this is a real arXiv ID)
        let result = client.get_paper("2301.00001").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    #[ignore] // Ignored by default to avoid hitting the ArXiv API in tests
    async fn test_search_recent_integration() {
        let client = ArxivClient::new();
        let results = client.search_recent("cs.LG", 7).await;
        assert!(results.is_ok());

        // Check that returned papers are within the date range
        let cutoff = Utc::now() - chrono::Duration::days(7);
        for vector in results.unwrap() {
            assert!(vector.timestamp >= cutoff);
        }
    }

    #[tokio::test]
    #[ignore] // Ignored by default to avoid hitting the ArXiv API in tests
    async fn test_multiple_categories_integration() {
        let client = ArxivClient::new();
        let categories = vec!["cs.AI", "cs.LG"];
        let results = client.search_multiple_categories(&categories, 2).await;
        assert!(results.is_ok());

        let vectors = results.unwrap();
        assert!(vectors.len() <= 4); // 2 categories * 2 results each
    }
}
571}