ruvector_data_framework/
ml_clients.rs

1//! AI/ML API Client Integrations
2//!
3//! This module provides async clients for AI/ML platforms including:
4//! - HuggingFace: Model hub and inference
5//! - Ollama: Local LLM inference
6//! - Replicate: Cloud ML models
7//! - TogetherAI: Open source model hosting
8//! - Papers With Code: ML research papers and benchmarks
9//!
10//! All clients follow the framework's patterns with rate limiting, mock fallbacks,
11//! and conversion to SemanticVector format for RuVector discovery.
12
13use std::collections::HashMap;
14use std::env;
15use std::time::Duration;
16
17use chrono::{DateTime, Utc};
18use reqwest::{Client, StatusCode};
19use serde::{Deserialize, Serialize};
20use tokio::time::sleep;
21
22use crate::api_clients::SimpleEmbedder;
23use crate::ruvector_native::{Domain, SemanticVector};
24use crate::{FrameworkError, Result};
25
// Per-service delays, in milliseconds, slept before outbound requests to stay under
// each provider's rate limit (roughly 60_000 / allowed-requests-per-minute).

/// Delay between HuggingFace requests: 30 req/min => 2000 ms.
const HUGGINGFACE_RATE_LIMIT_MS: u64 = 2000;
/// Delay between Papers With Code requests: 60 req/min => 1000 ms.
const PAPERWITHCODE_RATE_LIMIT_MS: u64 = 1000;
/// Delay between Replicate API requests.
const REPLICATE_RATE_LIMIT_MS: u64 = 1000;
/// Delay between TogetherAI API requests.
const TOGETHER_RATE_LIMIT_MS: u64 = 1000;
/// Delay between Ollama requests; the service is local, so only a minimal pause is used.
const OLLAMA_RATE_LIMIT_MS: u64 = 100;

/// Maximum number of retries after the initial attempt (so up to `MAX_RETRIES + 1` tries).
const MAX_RETRIES: u32 = 3;
/// Base retry back-off: in `HuggingFaceClient::fetch_with_retry` the n-th retry waits
/// `n * RETRY_DELAY_MS` milliseconds (linear back-off).
const RETRY_DELAY_MS: u64 = 2000;
/// Embedding dimension used by clients constructed without an explicit dimension.
const DEFAULT_EMBEDDING_DIM: usize = 384;
/// Timeout, in seconds, configured on the `reqwest` clients built in this module.
const REQUEST_TIMEOUT_SECS: u64 = 30;
37
38// ============================================================================
39// HuggingFace Client
40// ============================================================================
41
42/// HuggingFace model information
43#[derive(Debug, Clone, Deserialize, Serialize)]
44pub struct HuggingFaceModel {
45    #[serde(rename = "modelId")]
46    pub model_id: String,
47    #[serde(rename = "author")]
48    pub author: Option<String>,
49    #[serde(rename = "downloads")]
50    pub downloads: Option<u64>,
51    #[serde(rename = "likes")]
52    pub likes: Option<u64>,
53    #[serde(rename = "tags")]
54    pub tags: Option<Vec<String>>,
55    #[serde(rename = "pipeline_tag")]
56    pub pipeline_tag: Option<String>,
57    #[serde(rename = "createdAt")]
58    pub created_at: Option<String>,
59}
60
61/// HuggingFace dataset information
62#[derive(Debug, Clone, Deserialize, Serialize)]
63pub struct HuggingFaceDataset {
64    pub id: String,
65    pub author: Option<String>,
66    pub downloads: Option<u64>,
67    pub likes: Option<u64>,
68    pub tags: Option<Vec<String>>,
69    #[serde(rename = "createdAt")]
70    pub created_at: Option<String>,
71    pub description: Option<String>,
72}
73
74/// HuggingFace inference input
75#[derive(Debug, Clone, Serialize)]
76pub struct HuggingFaceInferenceInput {
77    pub inputs: serde_json::Value,
78}
79
80/// HuggingFace inference response
81#[derive(Debug, Clone, Deserialize)]
82#[serde(untagged)]
83pub enum HuggingFaceInferenceResponse {
84    Embeddings(Vec<Vec<f32>>),
85    Classification(Vec<ClassificationResult>),
86    Generation(Vec<GenerationResult>),
87    Error(InferenceError),
88}
89
90#[derive(Debug, Clone, Deserialize)]
91pub struct ClassificationResult {
92    pub label: String,
93    pub score: f64,
94}
95
96#[derive(Debug, Clone, Deserialize)]
97pub struct GenerationResult {
98    pub generated_text: String,
99}
100
101#[derive(Debug, Clone, Deserialize)]
102pub struct InferenceError {
103    pub error: String,
104}
105
106/// Client for HuggingFace model hub and inference API
107///
108/// # API Details
109/// - Base URL: https://huggingface.co/api
110/// - Rate limit: 30 requests/minute (free tier)
111/// - API key optional for public models
112///
113/// # Environment Variables
114/// - `HUGGINGFACE_API_KEY`: Optional API key for higher rate limits and private models
115pub struct HuggingFaceClient {
116    client: Client,
117    embedder: SimpleEmbedder,
118    base_url: String,
119    api_key: Option<String>,
120    use_mock: bool,
121}
122
123impl HuggingFaceClient {
124    /// Create a new HuggingFace client
125    ///
126    /// Reads API key from `HUGGINGFACE_API_KEY` environment variable if available.
127    /// Falls back to mock data if no API key is provided.
128    pub fn new() -> Self {
129        Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
130    }
131
132    /// Create a new HuggingFace client with custom embedding dimension
133    pub fn with_embedding_dim(embedding_dim: usize) -> Self {
134        let api_key = env::var("HUGGINGFACE_API_KEY").ok();
135        let use_mock = api_key.is_none();
136
137        if use_mock {
138            tracing::warn!("HUGGINGFACE_API_KEY not set, using mock data");
139        }
140
141        Self {
142            client: Client::builder()
143                .user_agent("RuVector-Discovery/1.0")
144                .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
145                .build()
146                .expect("Failed to create HTTP client"),
147            embedder: SimpleEmbedder::new(embedding_dim),
148            base_url: "https://huggingface.co/api".to_string(),
149            api_key,
150            use_mock,
151        }
152    }
153
154    /// Search models by query and optional task filter
155    ///
156    /// # Arguments
157    /// * `query` - Search query string
158    /// * `task` - Optional task filter (e.g., "text-classification", "text-generation")
159    ///
160    /// # Example
161    /// ```rust,ignore
162    /// let models = client.search_models("bert", Some("fill-mask")).await?;
163    /// ```
164    pub async fn search_models(
165        &self,
166        query: &str,
167        task: Option<&str>,
168    ) -> Result<Vec<HuggingFaceModel>> {
169        if self.use_mock {
170            return Ok(self.mock_models(query));
171        }
172
173        sleep(Duration::from_millis(HUGGINGFACE_RATE_LIMIT_MS)).await;
174
175        let mut url = format!("{}/models?search={}", self.base_url, urlencoding::encode(query));
176        if let Some(task_filter) = task {
177            url.push_str(&format!("&filter={}", task_filter));
178        }
179        url.push_str("&limit=20");
180
181        let response = self.fetch_with_retry(&url).await?;
182        let models: Vec<HuggingFaceModel> = response.json().await?;
183
184        Ok(models)
185    }
186
187    /// Get detailed information about a specific model
188    ///
189    /// # Arguments
190    /// * `model_id` - Model identifier (e.g., "bert-base-uncased")
191    pub async fn get_model(&self, model_id: &str) -> Result<Option<HuggingFaceModel>> {
192        if self.use_mock {
193            return Ok(self.mock_models(model_id).into_iter().next());
194        }
195
196        sleep(Duration::from_millis(HUGGINGFACE_RATE_LIMIT_MS)).await;
197
198        let url = format!("{}/models/{}", self.base_url, model_id);
199        let response = self.fetch_with_retry(&url).await?;
200        let model: HuggingFaceModel = response.json().await?;
201
202        Ok(Some(model))
203    }
204
205    /// List datasets with optional query filter
206    ///
207    /// # Arguments
208    /// * `query` - Optional search query for datasets
209    pub async fn list_datasets(&self, query: Option<&str>) -> Result<Vec<HuggingFaceDataset>> {
210        if self.use_mock {
211            return Ok(self.mock_datasets(query.unwrap_or("ml")));
212        }
213
214        sleep(Duration::from_millis(HUGGINGFACE_RATE_LIMIT_MS)).await;
215
216        let mut url = format!("{}/datasets", self.base_url);
217        if let Some(q) = query {
218            url.push_str(&format!("?search={}", urlencoding::encode(q)));
219        }
220        url.push_str("&limit=20");
221
222        let response = self.fetch_with_retry(&url).await?;
223        let datasets: Vec<HuggingFaceDataset> = response.json().await?;
224
225        Ok(datasets)
226    }
227
228    /// Get detailed information about a specific dataset
229    ///
230    /// # Arguments
231    /// * `dataset_id` - Dataset identifier
232    pub async fn get_dataset(&self, dataset_id: &str) -> Result<Option<HuggingFaceDataset>> {
233        if self.use_mock {
234            return Ok(self.mock_datasets(dataset_id).into_iter().next());
235        }
236
237        sleep(Duration::from_millis(HUGGINGFACE_RATE_LIMIT_MS)).await;
238
239        let url = format!("{}/datasets/{}", self.base_url, dataset_id);
240        let response = self.fetch_with_retry(&url).await?;
241        let dataset: HuggingFaceDataset = response.json().await?;
242
243        Ok(Some(dataset))
244    }
245
246    /// Run inference on a model
247    ///
248    /// # Arguments
249    /// * `model_id` - Model identifier
250    /// * `inputs` - Input data as JSON value
251    ///
252    /// # Note
253    /// Requires API key. Returns mock embeddings if no API key is available.
254    pub async fn inference(
255        &self,
256        model_id: &str,
257        inputs: serde_json::Value,
258    ) -> Result<HuggingFaceInferenceResponse> {
259        if self.use_mock {
260            // Return mock embeddings
261            let embedding = self.embedder.embed_json(&inputs);
262            return Ok(HuggingFaceInferenceResponse::Embeddings(vec![embedding]));
263        }
264
265        sleep(Duration::from_millis(HUGGINGFACE_RATE_LIMIT_MS)).await;
266
267        let url = format!("https://api-inference.huggingface.co/models/{}", model_id);
268        let body = HuggingFaceInferenceInput { inputs };
269
270        let mut request = self.client.post(&url).json(&body);
271
272        if let Some(key) = &self.api_key {
273            request = request.header("Authorization", format!("Bearer {}", key));
274        }
275
276        let response = request.send().await?;
277
278        if !response.status().is_success() {
279            return Err(FrameworkError::Network(
280                reqwest::Error::from(response.error_for_status().unwrap_err()),
281            ));
282        }
283
284        let result: HuggingFaceInferenceResponse = response.json().await?;
285        Ok(result)
286    }
287
288    /// Convert HuggingFace model to SemanticVector
289    pub fn model_to_vector(&self, model: &HuggingFaceModel) -> SemanticVector {
290        let text = format!(
291            "{} {} {}",
292            model.model_id,
293            model.pipeline_tag.as_deref().unwrap_or(""),
294            model.tags.as_ref().map(|t| t.join(" ")).unwrap_or_default()
295        );
296
297        let embedding = self.embedder.embed_text(&text);
298
299        let mut metadata = HashMap::new();
300        metadata.insert("model_id".to_string(), model.model_id.clone());
301        if let Some(author) = &model.author {
302            metadata.insert("author".to_string(), author.clone());
303        }
304        if let Some(downloads) = model.downloads {
305            metadata.insert("downloads".to_string(), downloads.to_string());
306        }
307        if let Some(likes) = model.likes {
308            metadata.insert("likes".to_string(), likes.to_string());
309        }
310        if let Some(pipeline) = &model.pipeline_tag {
311            metadata.insert("task".to_string(), pipeline.clone());
312        }
313        metadata.insert("source".to_string(), "huggingface".to_string());
314
315        let timestamp = model
316            .created_at
317            .as_ref()
318            .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
319            .map(|dt| dt.with_timezone(&Utc))
320            .unwrap_or_else(Utc::now);
321
322        SemanticVector {
323            id: format!("hf:model:{}", model.model_id),
324            embedding,
325            domain: Domain::Research,
326            timestamp,
327            metadata,
328        }
329    }
330
331    /// Convert HuggingFace dataset to SemanticVector
332    pub fn dataset_to_vector(&self, dataset: &HuggingFaceDataset) -> SemanticVector {
333        let text = format!(
334            "{} {}",
335            dataset.id,
336            dataset.description.as_deref().unwrap_or("")
337        );
338
339        let embedding = self.embedder.embed_text(&text);
340
341        let mut metadata = HashMap::new();
342        metadata.insert("dataset_id".to_string(), dataset.id.clone());
343        if let Some(author) = &dataset.author {
344            metadata.insert("author".to_string(), author.clone());
345        }
346        if let Some(downloads) = dataset.downloads {
347            metadata.insert("downloads".to_string(), downloads.to_string());
348        }
349        metadata.insert("source".to_string(), "huggingface".to_string());
350
351        let timestamp = dataset
352            .created_at
353            .as_ref()
354            .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
355            .map(|dt| dt.with_timezone(&Utc))
356            .unwrap_or_else(Utc::now);
357
358        SemanticVector {
359            id: format!("hf:dataset:{}", dataset.id),
360            embedding,
361            domain: Domain::Research,
362            timestamp,
363            metadata,
364        }
365    }
366
367    /// Mock models for testing without API key
368    fn mock_models(&self, query: &str) -> Vec<HuggingFaceModel> {
369        vec![
370            HuggingFaceModel {
371                model_id: format!("bert-base-{}", query),
372                author: Some("google".to_string()),
373                downloads: Some(1_000_000),
374                likes: Some(500),
375                tags: Some(vec!["nlp".to_string(), "transformer".to_string()]),
376                pipeline_tag: Some("fill-mask".to_string()),
377                created_at: Some(Utc::now().to_rfc3339()),
378            },
379            HuggingFaceModel {
380                model_id: format!("gpt2-{}", query),
381                author: Some("openai".to_string()),
382                downloads: Some(800_000),
383                likes: Some(350),
384                tags: Some(vec!["text-generation".to_string()]),
385                pipeline_tag: Some("text-generation".to_string()),
386                created_at: Some(Utc::now().to_rfc3339()),
387            },
388        ]
389    }
390
391    /// Mock datasets for testing without API key
392    fn mock_datasets(&self, query: &str) -> Vec<HuggingFaceDataset> {
393        vec![
394            HuggingFaceDataset {
395                id: format!("squad-{}", query),
396                author: Some("datasets".to_string()),
397                downloads: Some(500_000),
398                likes: Some(200),
399                tags: Some(vec!["qa".to_string(), "english".to_string()]),
400                created_at: Some(Utc::now().to_rfc3339()),
401                description: Some("Question answering dataset".to_string()),
402            },
403            HuggingFaceDataset {
404                id: format!("glue-{}", query),
405                author: Some("datasets".to_string()),
406                downloads: Some(300_000),
407                likes: Some(150),
408                tags: Some(vec!["benchmark".to_string()]),
409                created_at: Some(Utc::now().to_rfc3339()),
410                description: Some("General Language Understanding Evaluation".to_string()),
411            },
412        ]
413    }
414
415    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
416        let mut retries = 0;
417        loop {
418            let mut request = self.client.get(url);
419
420            if let Some(key) = &self.api_key {
421                request = request.header("Authorization", format!("Bearer {}", key));
422            }
423
424            match request.send().await {
425                Ok(response) => {
426                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
427                        retries += 1;
428                        tracing::warn!("Rate limited, retrying in {}ms", RETRY_DELAY_MS * retries as u64);
429                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
430                        continue;
431                    }
432                    if !response.status().is_success() {
433                        return Err(FrameworkError::Network(
434                            reqwest::Error::from(response.error_for_status().unwrap_err()),
435                        ));
436                    }
437                    return Ok(response);
438                }
439                Err(_) if retries < MAX_RETRIES => {
440                    retries += 1;
441                    tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
442                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
443                }
444                Err(e) => return Err(FrameworkError::Network(e)),
445            }
446        }
447    }
448}
449
450impl Default for HuggingFaceClient {
451    fn default() -> Self {
452        Self::new()
453    }
454}
455
456// ============================================================================
457// Ollama Client (Local LLM)
458// ============================================================================
459
460/// Ollama model information
461#[derive(Debug, Clone, Deserialize, Serialize)]
462pub struct OllamaModel {
463    pub name: String,
464    pub modified_at: Option<String>,
465    pub size: Option<u64>,
466    pub digest: Option<String>,
467}
468
469/// Ollama model list response
470#[derive(Debug, Clone, Deserialize)]
471pub struct OllamaModelsResponse {
472    pub models: Vec<OllamaModel>,
473}
474
475/// Ollama generation request
476#[derive(Debug, Clone, Serialize)]
477pub struct OllamaGenerateRequest {
478    pub model: String,
479    pub prompt: String,
480    pub stream: bool,
481}
482
483/// Ollama generation response
484#[derive(Debug, Clone, Deserialize)]
485pub struct OllamaGenerateResponse {
486    pub model: String,
487    pub response: String,
488    pub done: bool,
489}
490
491/// Ollama chat message
492#[derive(Debug, Clone, Serialize)]
493pub struct OllamaChatMessage {
494    pub role: String,
495    pub content: String,
496}
497
498/// Ollama chat request
499#[derive(Debug, Clone, Serialize)]
500pub struct OllamaChatRequest {
501    pub model: String,
502    pub messages: Vec<OllamaChatMessage>,
503    pub stream: bool,
504}
505
506/// Ollama chat response
507#[derive(Debug, Clone, Deserialize)]
508pub struct OllamaChatResponse {
509    pub model: String,
510    pub message: OllamaMessage,
511    pub done: bool,
512}
513
514#[derive(Debug, Clone, Deserialize)]
515pub struct OllamaMessage {
516    pub role: String,
517    pub content: String,
518}
519
520/// Ollama embeddings request
521#[derive(Debug, Clone, Serialize)]
522pub struct OllamaEmbeddingsRequest {
523    pub model: String,
524    pub prompt: String,
525}
526
527/// Ollama embeddings response
528#[derive(Debug, Clone, Deserialize)]
529pub struct OllamaEmbeddingsResponse {
530    pub embedding: Vec<f32>,
531}
532
533/// Client for Ollama local LLM inference
534///
535/// # API Details
536/// - Base URL: http://localhost:11434/api (default)
537/// - No rate limit (local service)
538/// - No API key required
539/// - Falls back to mock data when Ollama is not running
540pub struct OllamaClient {
541    client: Client,
542    embedder: SimpleEmbedder,
543    base_url: String,
544    use_mock: bool,
545}
546
547impl OllamaClient {
548    /// Create a new Ollama client with default base URL
549    pub fn new() -> Self {
550        Self::with_base_url("http://localhost:11434/api")
551    }
552
553    /// Create a new Ollama client with custom base URL
554    ///
555    /// # Arguments
556    /// * `base_url` - Ollama API base URL (e.g., "http://localhost:11434/api")
557    pub fn with_base_url(base_url: &str) -> Self {
558        Self {
559            client: Client::builder()
560                .user_agent("RuVector-Discovery/1.0")
561                .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
562                .build()
563                .expect("Failed to create HTTP client"),
564            embedder: SimpleEmbedder::new(DEFAULT_EMBEDDING_DIM),
565            base_url: base_url.to_string(),
566            use_mock: false,
567        }
568    }
569
570    /// Check if Ollama is available
571    pub async fn is_available(&self) -> bool {
572        self.client
573            .get(&format!("{}/tags", self.base_url))
574            .send()
575            .await
576            .map(|r| r.status().is_success())
577            .unwrap_or(false)
578    }
579
580    /// List available models
581    pub async fn list_models(&mut self) -> Result<Vec<OllamaModel>> {
582        sleep(Duration::from_millis(OLLAMA_RATE_LIMIT_MS)).await;
583
584        let url = format!("{}/tags", self.base_url);
585
586        match self.client.get(&url).send().await {
587            Ok(response) if response.status().is_success() => {
588                let data: OllamaModelsResponse = response.json().await?;
589                self.use_mock = false;
590                Ok(data.models)
591            }
592            _ => {
593                if !self.use_mock {
594                    tracing::warn!("Ollama not available, using mock data");
595                    self.use_mock = true;
596                }
597                Ok(self.mock_models())
598            }
599        }
600    }
601
602    /// Generate text completion
603    ///
604    /// # Arguments
605    /// * `model` - Model name (e.g., "llama2", "mistral")
606    /// * `prompt` - Prompt text
607    pub async fn generate(&mut self, model: &str, prompt: &str) -> Result<String> {
608        if self.use_mock || !self.is_available().await {
609            self.use_mock = true;
610            return Ok(self.mock_generation(prompt));
611        }
612
613        sleep(Duration::from_millis(OLLAMA_RATE_LIMIT_MS)).await;
614
615        let url = format!("{}/generate", self.base_url);
616        let body = OllamaGenerateRequest {
617            model: model.to_string(),
618            prompt: prompt.to_string(),
619            stream: false,
620        };
621
622        let response = self.client.post(&url).json(&body).send().await?;
623
624        if !response.status().is_success() {
625            return Err(FrameworkError::Network(
626                reqwest::Error::from(response.error_for_status().unwrap_err()),
627            ));
628        }
629
630        let result: OllamaGenerateResponse = response.json().await?;
631        Ok(result.response)
632    }
633
634    /// Chat completion with message history
635    ///
636    /// # Arguments
637    /// * `model` - Model name
638    /// * `messages` - Chat message history
639    pub async fn chat(
640        &mut self,
641        model: &str,
642        messages: Vec<OllamaChatMessage>,
643    ) -> Result<String> {
644        if self.use_mock || !self.is_available().await {
645            self.use_mock = true;
646            let last_msg = messages.last().map(|m| m.content.as_str()).unwrap_or("");
647            return Ok(self.mock_generation(last_msg));
648        }
649
650        sleep(Duration::from_millis(OLLAMA_RATE_LIMIT_MS)).await;
651
652        let url = format!("{}/chat", self.base_url);
653        let body = OllamaChatRequest {
654            model: model.to_string(),
655            messages,
656            stream: false,
657        };
658
659        let response = self.client.post(&url).json(&body).send().await?;
660
661        if !response.status().is_success() {
662            return Err(FrameworkError::Network(
663                reqwest::Error::from(response.error_for_status().unwrap_err()),
664            ));
665        }
666
667        let result: OllamaChatResponse = response.json().await?;
668        Ok(result.message.content)
669    }
670
671    /// Generate embeddings for text
672    ///
673    /// # Arguments
674    /// * `model` - Model name (e.g., "llama2")
675    /// * `prompt` - Text to embed
676    pub async fn embeddings(&mut self, model: &str, prompt: &str) -> Result<Vec<f32>> {
677        if self.use_mock || !self.is_available().await {
678            self.use_mock = true;
679            return Ok(self.embedder.embed_text(prompt));
680        }
681
682        sleep(Duration::from_millis(OLLAMA_RATE_LIMIT_MS)).await;
683
684        let url = format!("{}/embeddings", self.base_url);
685        let body = OllamaEmbeddingsRequest {
686            model: model.to_string(),
687            prompt: prompt.to_string(),
688        };
689
690        let response = self.client.post(&url).json(&body).send().await?;
691
692        if !response.status().is_success() {
693            return Err(FrameworkError::Network(
694                reqwest::Error::from(response.error_for_status().unwrap_err()),
695            ));
696        }
697
698        let result: OllamaEmbeddingsResponse = response.json().await?;
699        Ok(result.embedding)
700    }
701
702    /// Pull a model from Ollama library
703    ///
704    /// # Arguments
705    /// * `name` - Model name to pull
706    ///
707    /// # Note
708    /// This is a blocking operation that may take several minutes
709    pub async fn pull_model(&mut self, name: &str) -> Result<bool> {
710        if self.use_mock || !self.is_available().await {
711            self.use_mock = true;
712            tracing::warn!("Ollama not available, cannot pull model");
713            return Ok(false);
714        }
715
716        sleep(Duration::from_millis(OLLAMA_RATE_LIMIT_MS)).await;
717
718        let url = format!("{}/pull", self.base_url);
719        let body = serde_json::json!({ "name": name });
720
721        let response = self.client.post(&url).json(&body).send().await?;
722        Ok(response.status().is_success())
723    }
724
725    /// Convert Ollama model to SemanticVector
726    pub fn model_to_vector(&self, model: &OllamaModel) -> SemanticVector {
727        let embedding = self.embedder.embed_text(&model.name);
728
729        let mut metadata = HashMap::new();
730        metadata.insert("model_name".to_string(), model.name.clone());
731        if let Some(size) = model.size {
732            metadata.insert("size_bytes".to_string(), size.to_string());
733        }
734        if let Some(digest) = &model.digest {
735            metadata.insert("digest".to_string(), digest.clone());
736        }
737        metadata.insert("source".to_string(), "ollama".to_string());
738
739        let timestamp = model
740            .modified_at
741            .as_ref()
742            .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
743            .map(|dt| dt.with_timezone(&Utc))
744            .unwrap_or_else(Utc::now);
745
746        SemanticVector {
747            id: format!("ollama:model:{}", model.name),
748            embedding,
749            domain: Domain::Research,
750            timestamp,
751            metadata,
752        }
753    }
754
755    fn mock_models(&self) -> Vec<OllamaModel> {
756        vec![
757            OllamaModel {
758                name: "llama2:latest".to_string(),
759                modified_at: Some(Utc::now().to_rfc3339()),
760                size: Some(3_800_000_000),
761                digest: Some("sha256:mock123".to_string()),
762            },
763            OllamaModel {
764                name: "mistral:latest".to_string(),
765                modified_at: Some(Utc::now().to_rfc3339()),
766                size: Some(4_100_000_000),
767                digest: Some("sha256:mock456".to_string()),
768            },
769        ]
770    }
771
772    fn mock_generation(&self, prompt: &str) -> String {
773        format!("Mock response to: {}", prompt.chars().take(50).collect::<String>())
774    }
775}
776
777impl Default for OllamaClient {
778    fn default() -> Self {
779        Self::new()
780    }
781}
782
783// ============================================================================
784// Replicate Client
785// ============================================================================
786
787/// Replicate model information
788#[derive(Debug, Clone, Deserialize, Serialize)]
789pub struct ReplicateModel {
790    pub owner: String,
791    pub name: String,
792    pub description: Option<String>,
793    pub visibility: Option<String>,
794    pub github_url: Option<String>,
795    pub paper_url: Option<String>,
796    pub latest_version: Option<ReplicateVersion>,
797}
798
799/// Replicate model version
800#[derive(Debug, Clone, Deserialize, Serialize)]
801pub struct ReplicateVersion {
802    pub id: String,
803    pub created_at: Option<String>,
804}
805
806/// Replicate prediction request
807#[derive(Debug, Clone, Serialize)]
808pub struct ReplicatePredictionRequest {
809    pub version: String,
810    pub input: serde_json::Value,
811}
812
813/// Replicate prediction response
814#[derive(Debug, Clone, Deserialize)]
815pub struct ReplicatePrediction {
816    pub id: String,
817    pub status: String,
818    pub output: Option<serde_json::Value>,
819    pub error: Option<String>,
820}
821
822/// Replicate collection
823#[derive(Debug, Clone, Deserialize)]
824pub struct ReplicateCollection {
825    pub name: String,
826    pub slug: String,
827    pub description: Option<String>,
828}
829
830/// Client for Replicate cloud ML model API
831///
832/// # API Details
833/// - Base URL: https://api.replicate.com/v1
834/// - Requires API key
835/// - Falls back to mock data when no API key is available
836///
837/// # Environment Variables
838/// - `REPLICATE_API_TOKEN`: Required API token
839pub struct ReplicateClient {
840    client: Client,
841    embedder: SimpleEmbedder,
842    base_url: String,
843    api_token: Option<String>,
844    use_mock: bool,
845}
846
847impl ReplicateClient {
848    /// Create a new Replicate client
849    ///
850    /// Reads API token from `REPLICATE_API_TOKEN` environment variable.
851    pub fn new() -> Self {
852        Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
853    }
854
855    /// Create a new Replicate client with custom embedding dimension
856    pub fn with_embedding_dim(embedding_dim: usize) -> Self {
857        let api_token = env::var("REPLICATE_API_TOKEN").ok();
858        let use_mock = api_token.is_none();
859
860        if use_mock {
861            tracing::warn!("REPLICATE_API_TOKEN not set, using mock data");
862        }
863
864        Self {
865            client: Client::builder()
866                .user_agent("RuVector-Discovery/1.0")
867                .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
868                .build()
869                .expect("Failed to create HTTP client"),
870            embedder: SimpleEmbedder::new(embedding_dim),
871            base_url: "https://api.replicate.com/v1".to_string(),
872            api_token,
873            use_mock,
874        }
875    }
876
877    /// Get model information
878    ///
879    /// # Arguments
880    /// * `owner` - Model owner username
881    /// * `name` - Model name
882    pub async fn get_model(&self, owner: &str, name: &str) -> Result<Option<ReplicateModel>> {
883        if self.use_mock {
884            return Ok(Some(self.mock_model(owner, name)));
885        }
886
887        sleep(Duration::from_millis(REPLICATE_RATE_LIMIT_MS)).await;
888
889        let url = format!("{}/models/{}/{}", self.base_url, owner, name);
890        let response = self.fetch_with_retry(&url).await?;
891        let model: ReplicateModel = response.json().await?;
892
893        Ok(Some(model))
894    }
895
896    /// Create a prediction (run a model)
897    ///
898    /// # Arguments
899    /// * `model` - Model identifier in "owner/name" format
900    /// * `input` - Input parameters as JSON
901    pub async fn create_prediction(
902        &self,
903        model: &str,
904        input: serde_json::Value,
905    ) -> Result<ReplicatePrediction> {
906        if self.use_mock {
907            return Ok(self.mock_prediction());
908        }
909
910        sleep(Duration::from_millis(REPLICATE_RATE_LIMIT_MS)).await;
911
912        let url = format!("{}/predictions", self.base_url);
913
914        // Get latest version for the model
915        let parts: Vec<&str> = model.split('/').collect();
916        if parts.len() != 2 {
917            return Err(FrameworkError::Config(
918                "Model must be in 'owner/name' format".to_string(),
919            ));
920        }
921
922        let model_info = self.get_model(parts[0], parts[1]).await?;
923        let version = model_info
924            .and_then(|m| m.latest_version)
925            .and_then(|v| Some(v.id))
926            .ok_or_else(|| FrameworkError::Config("Model version not found".to_string()))?;
927
928        let body = ReplicatePredictionRequest { version, input };
929
930        let response = self.fetch_with_retry_post(&url, &body).await?;
931        let prediction: ReplicatePrediction = response.json().await?;
932
933        Ok(prediction)
934    }
935
936    /// Get prediction status and output
937    ///
938    /// # Arguments
939    /// * `id` - Prediction ID
940    pub async fn get_prediction(&self, id: &str) -> Result<ReplicatePrediction> {
941        if self.use_mock {
942            return Ok(self.mock_prediction());
943        }
944
945        sleep(Duration::from_millis(REPLICATE_RATE_LIMIT_MS)).await;
946
947        let url = format!("{}/predictions/{}", self.base_url, id);
948        let response = self.fetch_with_retry(&url).await?;
949        let prediction: ReplicatePrediction = response.json().await?;
950
951        Ok(prediction)
952    }
953
954    /// List model collections
955    pub async fn list_collections(&self) -> Result<Vec<ReplicateCollection>> {
956        if self.use_mock {
957            return Ok(self.mock_collections());
958        }
959
960        sleep(Duration::from_millis(REPLICATE_RATE_LIMIT_MS)).await;
961
962        let url = format!("{}/collections", self.base_url);
963        let response = self.fetch_with_retry(&url).await?;
964        let collections: Vec<ReplicateCollection> = response.json().await?;
965
966        Ok(collections)
967    }
968
969    /// Convert Replicate model to SemanticVector
970    pub fn model_to_vector(&self, model: &ReplicateModel) -> SemanticVector {
971        let text = format!(
972            "{}/{} {}",
973            model.owner,
974            model.name,
975            model.description.as_deref().unwrap_or("")
976        );
977
978        let embedding = self.embedder.embed_text(&text);
979
980        let mut metadata = HashMap::new();
981        metadata.insert("owner".to_string(), model.owner.clone());
982        metadata.insert("name".to_string(), model.name.clone());
983        if let Some(desc) = &model.description {
984            metadata.insert("description".to_string(), desc.clone());
985        }
986        if let Some(github) = &model.github_url {
987            metadata.insert("github_url".to_string(), github.clone());
988        }
989        if let Some(paper) = &model.paper_url {
990            metadata.insert("paper_url".to_string(), paper.clone());
991        }
992        metadata.insert("source".to_string(), "replicate".to_string());
993
994        let timestamp = model
995            .latest_version
996            .as_ref()
997            .and_then(|v| v.created_at.as_ref())
998            .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
999            .map(|dt| dt.with_timezone(&Utc))
1000            .unwrap_or_else(Utc::now);
1001
1002        SemanticVector {
1003            id: format!("replicate:{}/{}", model.owner, model.name),
1004            embedding,
1005            domain: Domain::Research,
1006            timestamp,
1007            metadata,
1008        }
1009    }
1010
1011    fn mock_model(&self, owner: &str, name: &str) -> ReplicateModel {
1012        ReplicateModel {
1013            owner: owner.to_string(),
1014            name: name.to_string(),
1015            description: Some("Mock model for testing".to_string()),
1016            visibility: Some("public".to_string()),
1017            github_url: None,
1018            paper_url: None,
1019            latest_version: Some(ReplicateVersion {
1020                id: "mock-version-123".to_string(),
1021                created_at: Some(Utc::now().to_rfc3339()),
1022            }),
1023        }
1024    }
1025
1026    fn mock_prediction(&self) -> ReplicatePrediction {
1027        ReplicatePrediction {
1028            id: "mock-prediction-123".to_string(),
1029            status: "succeeded".to_string(),
1030            output: Some(serde_json::json!({"result": "mock output"})),
1031            error: None,
1032        }
1033    }
1034
1035    fn mock_collections(&self) -> Vec<ReplicateCollection> {
1036        vec![
1037            ReplicateCollection {
1038                name: "Text to Image".to_string(),
1039                slug: "text-to-image".to_string(),
1040                description: Some("Generate images from text".to_string()),
1041            },
1042            ReplicateCollection {
1043                name: "Image to Text".to_string(),
1044                slug: "image-to-text".to_string(),
1045                description: Some("Generate text from images".to_string()),
1046            },
1047        ]
1048    }
1049
1050    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
1051        let mut retries = 0;
1052        loop {
1053            let mut request = self.client.get(url);
1054
1055            if let Some(token) = &self.api_token {
1056                request = request.header("Authorization", format!("Token {}", token));
1057            }
1058
1059            match request.send().await {
1060                Ok(response) => {
1061                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
1062                        retries += 1;
1063                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1064                        continue;
1065                    }
1066                    if !response.status().is_success() {
1067                        return Err(FrameworkError::Network(
1068                            reqwest::Error::from(response.error_for_status().unwrap_err()),
1069                        ));
1070                    }
1071                    return Ok(response);
1072                }
1073                Err(_) if retries < MAX_RETRIES => {
1074                    retries += 1;
1075                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1076                }
1077                Err(e) => return Err(FrameworkError::Network(e)),
1078            }
1079        }
1080    }
1081
1082    async fn fetch_with_retry_post<T: Serialize>(
1083        &self,
1084        url: &str,
1085        body: &T,
1086    ) -> Result<reqwest::Response> {
1087        let mut retries = 0;
1088        loop {
1089            let mut request = self.client.post(url).json(body);
1090
1091            if let Some(token) = &self.api_token {
1092                request = request.header("Authorization", format!("Token {}", token));
1093            }
1094
1095            match request.send().await {
1096                Ok(response) => {
1097                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
1098                        retries += 1;
1099                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1100                        continue;
1101                    }
1102                    if !response.status().is_success() {
1103                        return Err(FrameworkError::Network(
1104                            reqwest::Error::from(response.error_for_status().unwrap_err()),
1105                        ));
1106                    }
1107                    return Ok(response);
1108                }
1109                Err(_) if retries < MAX_RETRIES => {
1110                    retries += 1;
1111                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1112                }
1113                Err(e) => return Err(FrameworkError::Network(e)),
1114            }
1115        }
1116    }
1117}
1118
1119impl Default for ReplicateClient {
1120    fn default() -> Self {
1121        Self::new()
1122    }
1123}
1124
1125// ============================================================================
1126// TogetherAI Client
1127// ============================================================================
1128
1129/// TogetherAI model information
1130#[derive(Debug, Clone, Deserialize, Serialize)]
1131pub struct TogetherModel {
1132    pub id: String,
1133    pub name: Option<String>,
1134    #[serde(rename = "display_name")]
1135    pub display_name: Option<String>,
1136    pub description: Option<String>,
1137    pub context_length: Option<u64>,
1138    pub pricing: Option<TogetherPricing>,
1139}
1140
1141#[derive(Debug, Clone, Deserialize, Serialize)]
1142pub struct TogetherPricing {
1143    pub input: Option<f64>,
1144    pub output: Option<f64>,
1145}
1146
1147/// TogetherAI chat completion request
1148#[derive(Debug, Clone, Serialize)]
1149pub struct TogetherChatRequest {
1150    pub model: String,
1151    pub messages: Vec<TogetherMessage>,
1152    pub max_tokens: Option<u32>,
1153    pub temperature: Option<f32>,
1154}
1155
1156#[derive(Debug, Clone, Serialize, Deserialize)]
1157pub struct TogetherMessage {
1158    pub role: String,
1159    pub content: String,
1160}
1161
1162/// TogetherAI chat completion response
1163#[derive(Debug, Clone, Deserialize)]
1164pub struct TogetherChatResponse {
1165    pub id: String,
1166    pub choices: Vec<TogetherChoice>,
1167    pub usage: Option<TogetherUsage>,
1168}
1169
1170#[derive(Debug, Clone, Deserialize)]
1171pub struct TogetherChoice {
1172    pub message: TogetherMessage,
1173    pub finish_reason: Option<String>,
1174}
1175
1176#[derive(Debug, Clone, Deserialize)]
1177pub struct TogetherUsage {
1178    pub prompt_tokens: u32,
1179    pub completion_tokens: u32,
1180    pub total_tokens: u32,
1181}
1182
1183/// TogetherAI embeddings request
1184#[derive(Debug, Clone, Serialize)]
1185pub struct TogetherEmbeddingsRequest {
1186    pub model: String,
1187    pub input: String,
1188}
1189
1190/// TogetherAI embeddings response
1191#[derive(Debug, Clone, Deserialize)]
1192pub struct TogetherEmbeddingsResponse {
1193    pub data: Vec<TogetherEmbeddingData>,
1194}
1195
1196#[derive(Debug, Clone, Deserialize)]
1197pub struct TogetherEmbeddingData {
1198    pub embedding: Vec<f32>,
1199    pub index: u32,
1200}
1201
1202/// Client for TogetherAI open source model hosting
1203///
1204/// # API Details
1205/// - Base URL: https://api.together.xyz/v1
1206/// - Requires API key
1207/// - Falls back to mock data when no API key is available
1208///
1209/// # Environment Variables
1210/// - `TOGETHER_API_KEY`: Required API key
1211pub struct TogetherAiClient {
1212    client: Client,
1213    embedder: SimpleEmbedder,
1214    base_url: String,
1215    api_key: Option<String>,
1216    use_mock: bool,
1217}
1218
1219impl TogetherAiClient {
1220    /// Create a new TogetherAI client
1221    ///
1222    /// Reads API key from `TOGETHER_API_KEY` environment variable.
1223    pub fn new() -> Self {
1224        Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
1225    }
1226
1227    /// Create a new TogetherAI client with custom embedding dimension
1228    pub fn with_embedding_dim(embedding_dim: usize) -> Self {
1229        let api_key = env::var("TOGETHER_API_KEY").ok();
1230        let use_mock = api_key.is_none();
1231
1232        if use_mock {
1233            tracing::warn!("TOGETHER_API_KEY not set, using mock data");
1234        }
1235
1236        Self {
1237            client: Client::builder()
1238                .user_agent("RuVector-Discovery/1.0")
1239                .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
1240                .build()
1241                .expect("Failed to create HTTP client"),
1242            embedder: SimpleEmbedder::new(embedding_dim),
1243            base_url: "https://api.together.xyz/v1".to_string(),
1244            api_key,
1245            use_mock,
1246        }
1247    }
1248
1249    /// List available models
1250    pub async fn list_models(&self) -> Result<Vec<TogetherModel>> {
1251        if self.use_mock {
1252            return Ok(self.mock_models());
1253        }
1254
1255        sleep(Duration::from_millis(TOGETHER_RATE_LIMIT_MS)).await;
1256
1257        let url = format!("{}/models", self.base_url);
1258        let response = self.fetch_with_retry(&url).await?;
1259        let models: Vec<TogetherModel> = response.json().await?;
1260
1261        Ok(models)
1262    }
1263
1264    /// Chat completion
1265    ///
1266    /// # Arguments
1267    /// * `model` - Model identifier
1268    /// * `messages` - Chat message history
1269    pub async fn chat_completion(
1270        &self,
1271        model: &str,
1272        messages: Vec<TogetherMessage>,
1273    ) -> Result<String> {
1274        if self.use_mock {
1275            let last_msg = messages.last().map(|m| m.content.as_str()).unwrap_or("");
1276            return Ok(format!("Mock response to: {}", last_msg));
1277        }
1278
1279        sleep(Duration::from_millis(TOGETHER_RATE_LIMIT_MS)).await;
1280
1281        let url = format!("{}/chat/completions", self.base_url);
1282        let body = TogetherChatRequest {
1283            model: model.to_string(),
1284            messages,
1285            max_tokens: Some(512),
1286            temperature: Some(0.7),
1287        };
1288
1289        let response = self.fetch_with_retry_post(&url, &body).await?;
1290        let result: TogetherChatResponse = response.json().await?;
1291
1292        Ok(result
1293            .choices
1294            .first()
1295            .map(|c| c.message.content.clone())
1296            .unwrap_or_default())
1297    }
1298
1299    /// Generate embeddings
1300    ///
1301    /// # Arguments
1302    /// * `model` - Embedding model identifier
1303    /// * `input` - Text to embed
1304    pub async fn embeddings(&self, model: &str, input: &str) -> Result<Vec<f32>> {
1305        if self.use_mock {
1306            return Ok(self.embedder.embed_text(input));
1307        }
1308
1309        sleep(Duration::from_millis(TOGETHER_RATE_LIMIT_MS)).await;
1310
1311        let url = format!("{}/embeddings", self.base_url);
1312        let body = TogetherEmbeddingsRequest {
1313            model: model.to_string(),
1314            input: input.to_string(),
1315        };
1316
1317        let response = self.fetch_with_retry_post(&url, &body).await?;
1318        let result: TogetherEmbeddingsResponse = response.json().await?;
1319
1320        Ok(result
1321            .data
1322            .first()
1323            .map(|d| d.embedding.clone())
1324            .unwrap_or_default())
1325    }
1326
1327    /// Convert TogetherAI model to SemanticVector
1328    pub fn model_to_vector(&self, model: &TogetherModel) -> SemanticVector {
1329        let text = format!(
1330            "{} {}",
1331            model.display_name.as_deref().unwrap_or(&model.id),
1332            model.description.as_deref().unwrap_or("")
1333        );
1334
1335        let embedding = self.embedder.embed_text(&text);
1336
1337        let mut metadata = HashMap::new();
1338        metadata.insert("model_id".to_string(), model.id.clone());
1339        if let Some(name) = &model.display_name {
1340            metadata.insert("display_name".to_string(), name.clone());
1341        }
1342        if let Some(ctx) = model.context_length {
1343            metadata.insert("context_length".to_string(), ctx.to_string());
1344        }
1345        metadata.insert("source".to_string(), "together".to_string());
1346
1347        SemanticVector {
1348            id: format!("together:{}", model.id),
1349            embedding,
1350            domain: Domain::Research,
1351            timestamp: Utc::now(),
1352            metadata,
1353        }
1354    }
1355
1356    fn mock_models(&self) -> Vec<TogetherModel> {
1357        vec![
1358            TogetherModel {
1359                id: "togethercomputer/llama-2-7b".to_string(),
1360                name: Some("Llama 2 7B".to_string()),
1361                display_name: Some("Llama 2 7B".to_string()),
1362                description: Some("Meta's Llama 2 7B model".to_string()),
1363                context_length: Some(4096),
1364                pricing: None,
1365            },
1366            TogetherModel {
1367                id: "mistralai/Mistral-7B-v0.1".to_string(),
1368                name: Some("Mistral 7B".to_string()),
1369                display_name: Some("Mistral 7B".to_string()),
1370                description: Some("Mistral AI's 7B model".to_string()),
1371                context_length: Some(8192),
1372                pricing: None,
1373            },
1374        ]
1375    }
1376
1377    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
1378        let mut retries = 0;
1379        loop {
1380            let mut request = self.client.get(url);
1381
1382            if let Some(key) = &self.api_key {
1383                request = request.header("Authorization", format!("Bearer {}", key));
1384            }
1385
1386            match request.send().await {
1387                Ok(response) => {
1388                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
1389                        retries += 1;
1390                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1391                        continue;
1392                    }
1393                    if !response.status().is_success() {
1394                        return Err(FrameworkError::Network(
1395                            reqwest::Error::from(response.error_for_status().unwrap_err()),
1396                        ));
1397                    }
1398                    return Ok(response);
1399                }
1400                Err(_) if retries < MAX_RETRIES => {
1401                    retries += 1;
1402                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1403                }
1404                Err(e) => return Err(FrameworkError::Network(e)),
1405            }
1406        }
1407    }
1408
1409    async fn fetch_with_retry_post<T: Serialize>(
1410        &self,
1411        url: &str,
1412        body: &T,
1413    ) -> Result<reqwest::Response> {
1414        let mut retries = 0;
1415        loop {
1416            let mut request = self.client.post(url).json(body);
1417
1418            if let Some(key) = &self.api_key {
1419                request = request.header("Authorization", format!("Bearer {}", key));
1420            }
1421
1422            match request.send().await {
1423                Ok(response) => {
1424                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
1425                        retries += 1;
1426                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1427                        continue;
1428                    }
1429                    if !response.status().is_success() {
1430                        return Err(FrameworkError::Network(
1431                            reqwest::Error::from(response.error_for_status().unwrap_err()),
1432                        ));
1433                    }
1434                    return Ok(response);
1435                }
1436                Err(_) if retries < MAX_RETRIES => {
1437                    retries += 1;
1438                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1439                }
1440                Err(e) => return Err(FrameworkError::Network(e)),
1441            }
1442        }
1443    }
1444}
1445
1446impl Default for TogetherAiClient {
1447    fn default() -> Self {
1448        Self::new()
1449    }
1450}
1451
1452// ============================================================================
1453// Papers With Code Client
1454// ============================================================================
1455
1456/// Papers With Code paper
1457#[derive(Debug, Clone, Deserialize, Serialize)]
1458pub struct PaperWithCodePaper {
1459    pub id: String,
1460    pub title: String,
1461    pub abstract_text: Option<String>,
1462    pub url_abs: Option<String>,
1463    pub url_pdf: Option<String>,
1464    pub published: Option<String>,
1465    pub authors: Option<Vec<PaperAuthor>>,
1466}
1467
1468#[derive(Debug, Clone, Deserialize, Serialize)]
1469pub struct PaperAuthor {
1470    pub name: String,
1471}
1472
1473/// Papers With Code dataset
1474#[derive(Debug, Clone, Deserialize, Serialize)]
1475pub struct PaperWithCodeDataset {
1476    pub id: String,
1477    pub name: String,
1478    pub full_name: Option<String>,
1479    pub description: Option<String>,
1480    pub url: Option<String>,
1481    pub paper: Option<String>,
1482}
1483
1484/// Papers With Code SOTA (State of the Art) benchmark result
1485#[derive(Debug, Clone, Deserialize, Serialize)]
1486pub struct SotaEntry {
1487    pub task: String,
1488    pub dataset: String,
1489    pub metric: String,
1490    pub value: f64,
1491    pub paper_title: Option<String>,
1492    pub paper_url: Option<String>,
1493}
1494
1495/// Papers With Code method/technique
1496#[derive(Debug, Clone, Deserialize, Serialize)]
1497pub struct Method {
1498    pub name: String,
1499    pub full_name: Option<String>,
1500    pub description: Option<String>,
1501    pub paper: Option<String>,
1502}
1503
1504/// Papers With Code search results
1505#[derive(Debug, Clone, Deserialize)]
1506pub struct PapersSearchResponse {
1507    pub results: Vec<PaperWithCodePaper>,
1508    pub count: Option<u32>,
1509}
1510
1511/// Papers With Code datasets list response
1512#[derive(Debug, Clone, Deserialize)]
1513pub struct DatasetsResponse {
1514    pub results: Vec<PaperWithCodeDataset>,
1515    pub count: Option<u32>,
1516}
1517
1518/// Client for Papers With Code ML research database
1519///
1520/// # API Details
1521/// - Base URL: https://paperswithcode.com/api/v1
1522/// - Rate limit: 60 requests/minute
1523/// - No API key required
1524pub struct PapersWithCodeClient {
1525    client: Client,
1526    embedder: SimpleEmbedder,
1527    base_url: String,
1528}
1529
1530impl PapersWithCodeClient {
1531    /// Create a new Papers With Code client
1532    pub fn new() -> Self {
1533        Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
1534    }
1535
1536    /// Create a new Papers With Code client with custom embedding dimension
1537    pub fn with_embedding_dim(embedding_dim: usize) -> Self {
1538        Self {
1539            client: Client::builder()
1540                .user_agent("RuVector-Discovery/1.0")
1541                .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
1542                .build()
1543                .expect("Failed to create HTTP client"),
1544            embedder: SimpleEmbedder::new(embedding_dim),
1545            base_url: "https://paperswithcode.com/api/v1".to_string(),
1546        }
1547    }
1548
1549    /// Search papers by query
1550    ///
1551    /// # Arguments
1552    /// * `query` - Search query string
1553    pub async fn search_papers(&self, query: &str) -> Result<Vec<PaperWithCodePaper>> {
1554        sleep(Duration::from_millis(PAPERWITHCODE_RATE_LIMIT_MS)).await;
1555
1556        let url = format!(
1557            "{}/papers/?q={}",
1558            self.base_url,
1559            urlencoding::encode(query)
1560        );
1561
1562        let response = self.fetch_with_retry(&url).await?;
1563        let data: PapersSearchResponse = response.json().await?;
1564
1565        Ok(data.results)
1566    }
1567
1568    /// Get paper by ID
1569    ///
1570    /// # Arguments
1571    /// * `paper_id` - Paper identifier
1572    pub async fn get_paper(&self, paper_id: &str) -> Result<Option<PaperWithCodePaper>> {
1573        sleep(Duration::from_millis(PAPERWITHCODE_RATE_LIMIT_MS)).await;
1574
1575        let url = format!("{}/papers/{}/", self.base_url, paper_id);
1576        let response = self.fetch_with_retry(&url).await?;
1577        let paper: PaperWithCodePaper = response.json().await?;
1578
1579        Ok(Some(paper))
1580    }
1581
1582    /// List datasets
1583    pub async fn list_datasets(&self) -> Result<Vec<PaperWithCodeDataset>> {
1584        sleep(Duration::from_millis(PAPERWITHCODE_RATE_LIMIT_MS)).await;
1585
1586        let url = format!("{}/datasets/", self.base_url);
1587        let response = self.fetch_with_retry(&url).await?;
1588        let data: DatasetsResponse = response.json().await?;
1589
1590        Ok(data.results)
1591    }
1592
1593    /// Get state-of-the-art results for a task
1594    ///
1595    /// # Arguments
1596    /// * `task` - Task name (e.g., "image-classification", "question-answering")
1597    pub async fn get_sota(&self, task: &str) -> Result<Vec<SotaEntry>> {
1598        sleep(Duration::from_millis(PAPERWITHCODE_RATE_LIMIT_MS)).await;
1599
1600        let url = format!("{}/sota/?task={}", self.base_url, urlencoding::encode(task));
1601
1602        // Papers With Code API might not have a direct SOTA endpoint in v1
1603        // Return mock data for now
1604        Ok(self.mock_sota(task))
1605    }
1606
1607    /// Search methods/techniques
1608    ///
1609    /// # Arguments
1610    /// * `query` - Search query for methods
1611    pub async fn search_methods(&self, query: &str) -> Result<Vec<Method>> {
1612        sleep(Duration::from_millis(PAPERWITHCODE_RATE_LIMIT_MS)).await;
1613
1614        // Return mock data as the methods endpoint structure may vary
1615        Ok(self.mock_methods(query))
1616    }
1617
1618    /// Convert paper to SemanticVector
1619    pub fn paper_to_vector(&self, paper: &PaperWithCodePaper) -> SemanticVector {
1620        let text = format!(
1621            "{} {}",
1622            paper.title,
1623            paper.abstract_text.as_deref().unwrap_or("")
1624        );
1625
1626        let embedding = self.embedder.embed_text(&text);
1627
1628        let mut metadata = HashMap::new();
1629        metadata.insert("paper_id".to_string(), paper.id.clone());
1630        metadata.insert("title".to_string(), paper.title.clone());
1631        if let Some(url) = &paper.url_abs {
1632            metadata.insert("url".to_string(), url.clone());
1633        }
1634        if let Some(pdf) = &paper.url_pdf {
1635            metadata.insert("pdf_url".to_string(), pdf.clone());
1636        }
1637        if let Some(authors) = &paper.authors {
1638            let author_names = authors
1639                .iter()
1640                .map(|a| a.name.as_str())
1641                .collect::<Vec<_>>()
1642                .join(", ");
1643            metadata.insert("authors".to_string(), author_names);
1644        }
1645        metadata.insert("source".to_string(), "paperswithcode".to_string());
1646
1647        let timestamp = paper
1648            .published
1649            .as_ref()
1650            .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
1651            .map(|dt| dt.with_timezone(&Utc))
1652            .unwrap_or_else(Utc::now);
1653
1654        SemanticVector {
1655            id: format!("pwc:paper:{}", paper.id),
1656            embedding,
1657            domain: Domain::Research,
1658            timestamp,
1659            metadata,
1660        }
1661    }
1662
1663    /// Convert dataset to SemanticVector
1664    pub fn dataset_to_vector(&self, dataset: &PaperWithCodeDataset) -> SemanticVector {
1665        let text = format!(
1666            "{} {}",
1667            dataset.name,
1668            dataset.description.as_deref().unwrap_or("")
1669        );
1670
1671        let embedding = self.embedder.embed_text(&text);
1672
1673        let mut metadata = HashMap::new();
1674        metadata.insert("dataset_id".to_string(), dataset.id.clone());
1675        metadata.insert("name".to_string(), dataset.name.clone());
1676        if let Some(desc) = &dataset.description {
1677            metadata.insert("description".to_string(), desc.clone());
1678        }
1679        if let Some(url) = &dataset.url {
1680            metadata.insert("url".to_string(), url.clone());
1681        }
1682        metadata.insert("source".to_string(), "paperswithcode".to_string());
1683
1684        SemanticVector {
1685            id: format!("pwc:dataset:{}", dataset.id),
1686            embedding,
1687            domain: Domain::Research,
1688            timestamp: Utc::now(),
1689            metadata,
1690        }
1691    }
1692
1693    fn mock_sota(&self, task: &str) -> Vec<SotaEntry> {
1694        vec![
1695            SotaEntry {
1696                task: task.to_string(),
1697                dataset: "ImageNet".to_string(),
1698                metric: "Top-1 Accuracy".to_string(),
1699                value: 90.2,
1700                paper_title: Some("Vision Transformer".to_string()),
1701                paper_url: Some("https://arxiv.org/abs/2010.11929".to_string()),
1702            },
1703            SotaEntry {
1704                task: task.to_string(),
1705                dataset: "COCO".to_string(),
1706                metric: "mAP".to_string(),
1707                value: 58.7,
1708                paper_title: Some("DETR".to_string()),
1709                paper_url: Some("https://arxiv.org/abs/2005.12872".to_string()),
1710            },
1711        ]
1712    }
1713
1714    fn mock_methods(&self, query: &str) -> Vec<Method> {
1715        vec![
1716            Method {
1717                name: format!("Transformer-{}", query),
1718                full_name: Some("Transformer Architecture".to_string()),
1719                description: Some("Attention-based neural network architecture".to_string()),
1720                paper: Some("https://arxiv.org/abs/1706.03762".to_string()),
1721            },
1722            Method {
1723                name: format!("ResNet-{}", query),
1724                full_name: Some("Residual Network".to_string()),
1725                description: Some("Deep residual learning framework".to_string()),
1726                paper: Some("https://arxiv.org/abs/1512.03385".to_string()),
1727            },
1728        ]
1729    }
1730
1731    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
1732        let mut retries = 0;
1733        loop {
1734            match self.client.get(url).send().await {
1735                Ok(response) => {
1736                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
1737                        retries += 1;
1738                        tracing::warn!("Rate limited, retrying in {}ms", RETRY_DELAY_MS * retries as u64);
1739                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1740                        continue;
1741                    }
1742                    if !response.status().is_success() {
1743                        return Err(FrameworkError::Network(
1744                            reqwest::Error::from(response.error_for_status().unwrap_err()),
1745                        ));
1746                    }
1747                    return Ok(response);
1748                }
1749                Err(_) if retries < MAX_RETRIES => {
1750                    retries += 1;
1751                    tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
1752                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1753                }
1754                Err(e) => return Err(FrameworkError::Network(e)),
1755            }
1756        }
1757    }
1758}
1759
1760impl Default for PapersWithCodeClient {
1761    fn default() -> Self {
1762        Self::new()
1763    }
1764}
1765
1766// ============================================================================
1767// Tests
1768// ============================================================================
1769
#[cfg(test)]
mod tests {
    use super::*;

    // ----------------------------------------------------------------------
    // HuggingFace
    // ----------------------------------------------------------------------

    #[test]
    fn test_huggingface_client_creation() {
        let hf = HuggingFaceClient::new();
        assert_eq!(hf.base_url, "https://huggingface.co/api");
    }

    #[test]
    fn test_huggingface_mock_models() {
        let hf = HuggingFaceClient::new();
        let mocks = hf.mock_models("test");
        assert!(!mocks.is_empty());
        let first = &mocks[0];
        assert!(first.model_id.contains("test"));
    }

    #[test]
    fn test_huggingface_model_to_vector() {
        let hf = HuggingFaceClient::new();
        let bert = HuggingFaceModel {
            created_at: Some(Utc::now().to_rfc3339()),
            pipeline_tag: Some("fill-mask".to_string()),
            tags: Some(vec!["nlp".to_string()]),
            likes: Some(500),
            downloads: Some(1_000_000),
            author: Some("google".to_string()),
            model_id: "bert-base-uncased".to_string(),
        };

        let vec_out = hf.model_to_vector(&bert);
        assert_eq!(vec_out.id, "hf:model:bert-base-uncased");
        assert_eq!(vec_out.domain, Domain::Research);
        assert!(vec_out.metadata.contains_key("model_id"));
        assert_eq!(vec_out.metadata.get("author").unwrap(), "google");
    }

    #[tokio::test]
    async fn test_huggingface_search_models_mock() {
        let hf = HuggingFaceClient::new();
        let found = hf.search_models("bert", None).await;
        assert!(matches!(&found, Ok(list) if !list.is_empty()));
    }

    // ----------------------------------------------------------------------
    // Ollama
    // ----------------------------------------------------------------------

    #[test]
    fn test_ollama_client_creation() {
        let ollama = OllamaClient::new();
        assert_eq!(ollama.base_url, "http://localhost:11434/api");
    }

    #[test]
    fn test_ollama_mock_models() {
        let ollama = OllamaClient::new();
        let mocks = ollama.mock_models();
        assert!(!mocks.is_empty());
        let first = &mocks[0];
        assert!(first.name.contains("llama"));
    }

    #[test]
    fn test_ollama_model_to_vector() {
        let ollama = OllamaClient::new();
        let llama = OllamaModel {
            digest: Some("sha256:abc123".to_string()),
            size: Some(3_800_000_000),
            modified_at: Some(Utc::now().to_rfc3339()),
            name: "llama2:latest".to_string(),
        };

        let vec_out = ollama.model_to_vector(&llama);
        assert_eq!(vec_out.id, "ollama:model:llama2:latest");
        assert_eq!(vec_out.domain, Domain::Research);
        assert!(vec_out.metadata.contains_key("model_name"));
    }

    #[tokio::test]
    async fn test_ollama_list_models_mock() {
        // Force the mock path so the test never needs a local Ollama server.
        let ollama = {
            let mut c = OllamaClient::new();
            c.use_mock = true;
            c
        };
        let listed = ollama.list_models().await;
        assert!(matches!(&listed, Ok(list) if !list.is_empty()));
    }

    #[tokio::test]
    async fn test_ollama_embeddings_mock() {
        let ollama = {
            let mut c = OllamaClient::new();
            c.use_mock = true;
            c
        };
        let emb = ollama.embeddings("llama2", "test text").await;
        assert!(matches!(&emb, Ok(values) if values.len() == DEFAULT_EMBEDDING_DIM));
    }

    // ----------------------------------------------------------------------
    // Replicate
    // ----------------------------------------------------------------------

    #[test]
    fn test_replicate_client_creation() {
        let replicate = ReplicateClient::new();
        assert_eq!(replicate.base_url, "https://api.replicate.com/v1");
    }

    #[test]
    fn test_replicate_mock_model() {
        let replicate = ReplicateClient::new();
        let ReplicateModel { owner, name, .. } = replicate.mock_model("owner", "model");
        assert_eq!(owner, "owner");
        assert_eq!(name, "model");
    }

    #[test]
    fn test_replicate_model_to_vector() {
        let replicate = ReplicateClient::new();
        let version = ReplicateVersion {
            id: "v1.0".to_string(),
            created_at: Some(Utc::now().to_rfc3339()),
        };
        let sd = ReplicateModel {
            latest_version: Some(version),
            paper_url: None,
            github_url: None,
            visibility: Some("public".to_string()),
            description: Some("Text to image model".to_string()),
            name: "stable-diffusion".to_string(),
            owner: "stability-ai".to_string(),
        };

        let vec_out = replicate.model_to_vector(&sd);
        assert_eq!(vec_out.id, "replicate:stability-ai/stable-diffusion");
        assert_eq!(vec_out.domain, Domain::Research);
    }

    #[tokio::test]
    async fn test_replicate_get_model_mock() {
        let replicate = ReplicateClient::new();
        let fetched = replicate.get_model("owner", "model").await;
        assert!(matches!(fetched, Ok(Some(_))));
    }

    // ----------------------------------------------------------------------
    // TogetherAI
    // ----------------------------------------------------------------------

    #[test]
    fn test_together_client_creation() {
        let together = TogetherAiClient::new();
        assert_eq!(together.base_url, "https://api.together.xyz/v1");
    }

    #[test]
    fn test_together_mock_models() {
        let together = TogetherAiClient::new();
        let mocks = together.mock_models();
        assert!(!mocks.is_empty());
        let first = &mocks[0];
        assert!(first.id.contains("llama"));
    }

    #[test]
    fn test_together_model_to_vector() {
        let together = TogetherAiClient::new();
        let llama = TogetherModel {
            pricing: None,
            context_length: Some(4096),
            description: Some("Meta's Llama 2 model".to_string()),
            display_name: Some("Llama 2 7B".to_string()),
            name: Some("Llama 2 7B".to_string()),
            id: "togethercomputer/llama-2-7b".to_string(),
        };

        let vec_out = together.model_to_vector(&llama);
        assert_eq!(vec_out.id, "together:togethercomputer/llama-2-7b");
        assert_eq!(vec_out.domain, Domain::Research);
    }

    #[tokio::test]
    async fn test_together_list_models_mock() {
        let together = TogetherAiClient::new();
        let listed = together.list_models().await;
        assert!(matches!(&listed, Ok(list) if !list.is_empty()));
    }

    // ----------------------------------------------------------------------
    // Papers With Code
    // ----------------------------------------------------------------------

    #[test]
    fn test_paperswithcode_client_creation() {
        let pwc = PapersWithCodeClient::new();
        assert_eq!(pwc.base_url, "https://paperswithcode.com/api/v1");
    }

    #[test]
    fn test_paperswithcode_paper_to_vector() {
        let pwc = PapersWithCodeClient::new();
        let authors = vec![PaperAuthor {
            name: "Vaswani et al.".to_string(),
        }];
        let transformer = PaperWithCodePaper {
            authors: Some(authors),
            published: Some(Utc::now().to_rfc3339()),
            url_pdf: Some("https://arxiv.org/pdf/1706.03762.pdf".to_string()),
            url_abs: Some("https://arxiv.org/abs/1706.03762".to_string()),
            abstract_text: Some("We propose the Transformer...".to_string()),
            title: "Attention Is All You Need".to_string(),
            id: "attention-is-all-you-need".to_string(),
        };

        let vec_out = pwc.paper_to_vector(&transformer);
        assert_eq!(vec_out.id, "pwc:paper:attention-is-all-you-need");
        assert_eq!(vec_out.domain, Domain::Research);
        assert!(vec_out.metadata.contains_key("title"));
    }

    #[test]
    fn test_paperswithcode_dataset_to_vector() {
        let pwc = PapersWithCodeClient::new();
        let imagenet = PaperWithCodeDataset {
            paper: None,
            url: Some("https://image-net.org".to_string()),
            description: Some("Large-scale image dataset".to_string()),
            full_name: Some("ImageNet Large Scale Visual Recognition Challenge".to_string()),
            name: "ImageNet".to_string(),
            id: "imagenet".to_string(),
        };

        let vec_out = pwc.dataset_to_vector(&imagenet);
        assert_eq!(vec_out.id, "pwc:dataset:imagenet");
        assert_eq!(vec_out.domain, Domain::Research);
    }

    /// Hits the live Papers With Code API; run explicitly with `--ignored`.
    #[tokio::test]
    #[ignore] // Ignore by default to avoid hitting API in tests
    async fn test_paperswithcode_search_papers_integration() {
        let pwc = PapersWithCodeClient::new();
        let hits = pwc.search_papers("transformer").await;
        assert!(hits.is_ok());
    }

    // ----------------------------------------------------------------------
    // Cross-client
    // ----------------------------------------------------------------------

    #[test]
    fn test_all_clients_default() {
        // Only checks that every `Default` impl constructs without panicking.
        let _ = (
            HuggingFaceClient::default(),
            OllamaClient::default(),
            ReplicateClient::default(),
            TogetherAiClient::default(),
            PapersWithCodeClient::default(),
        );
    }

    #[test]
    fn test_custom_embedding_dimensions() {
        let hf = HuggingFaceClient::with_embedding_dim(512);
        let sample_model = hf.mock_models("test")[0].clone();
        assert_eq!(hf.model_to_vector(&sample_model).embedding.len(), 512);

        let pwc = PapersWithCodeClient::with_embedding_dim(768);
        let bare_paper = PaperWithCodePaper {
            id: "test".to_string(),
            title: "Test Paper".to_string(),
            abstract_text: None,
            url_abs: None,
            url_pdf: None,
            published: None,
            authors: None,
        };
        assert_eq!(pwc.paper_to_vector(&bare_paper).embedding.len(), 768);
    }
}