1use std::collections::HashMap;
14use std::env;
15use std::time::Duration;
16
17use chrono::{DateTime, Utc};
18use reqwest::{Client, StatusCode};
19use serde::{Deserialize, Serialize};
20use tokio::time::sleep;
21
22use crate::api_clients::SimpleEmbedder;
23use crate::ruvector_native::{Domain, SemanticVector};
24use crate::{FrameworkError, Result};
25
// Minimum delay between requests to each upstream API, in milliseconds.
const HUGGINGFACE_RATE_LIMIT_MS: u64 = 2000;
const PAPERWITHCODE_RATE_LIMIT_MS: u64 = 1000;
const REPLICATE_RATE_LIMIT_MS: u64 = 1000;
const TOGETHER_RATE_LIMIT_MS: u64 = 1000;
const OLLAMA_RATE_LIMIT_MS: u64 = 100;

// Retry policy for transient failures (HTTP 429 and transport errors).
const MAX_RETRIES: u32 = 3;
const RETRY_DELAY_MS: u64 = 2000;

// Width of the local fallback `SimpleEmbedder` vectors.
const DEFAULT_EMBEDDING_DIM: usize = 384;
// Per-request HTTP timeout.
const REQUEST_TIMEOUT_SECS: u64 = 30;
37
38#[derive(Debug, Clone, Deserialize, Serialize)]
44pub struct HuggingFaceModel {
45 #[serde(rename = "modelId")]
46 pub model_id: String,
47 #[serde(rename = "author")]
48 pub author: Option<String>,
49 #[serde(rename = "downloads")]
50 pub downloads: Option<u64>,
51 #[serde(rename = "likes")]
52 pub likes: Option<u64>,
53 #[serde(rename = "tags")]
54 pub tags: Option<Vec<String>>,
55 #[serde(rename = "pipeline_tag")]
56 pub pipeline_tag: Option<String>,
57 #[serde(rename = "createdAt")]
58 pub created_at: Option<String>,
59}
60
/// Metadata for a dataset listed by the Hugging Face Hub API.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct HuggingFaceDataset {
    pub id: String,
    pub author: Option<String>,
    pub downloads: Option<u64>,
    pub likes: Option<u64>,
    pub tags: Option<Vec<String>>,
    /// RFC 3339 timestamp string as returned by the hub (`createdAt`).
    #[serde(rename = "createdAt")]
    pub created_at: Option<String>,
    pub description: Option<String>,
}
73
/// Request body for the Hugging Face Inference API; `inputs` is an
/// arbitrary JSON payload whose shape depends on the target model.
#[derive(Debug, Clone, Serialize)]
pub struct HuggingFaceInferenceInput {
    pub inputs: serde_json::Value,
}
79
/// Possible shapes of a Hugging Face Inference API response.
///
/// NOTE: `#[serde(untagged)]` tries variants in declaration order, so the
/// order below is load-bearing — do not reorder.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum HuggingFaceInferenceResponse {
    /// Sentence-embedding style output: one vector per input.
    Embeddings(Vec<Vec<f32>>),
    /// Classifier output: label/score pairs.
    Classification(Vec<ClassificationResult>),
    /// Text-generation output.
    Generation(Vec<GenerationResult>),
    /// Error payload returned by the API.
    Error(InferenceError),
}
89
/// One label/score pair from a classification model.
#[derive(Debug, Clone, Deserialize)]
pub struct ClassificationResult {
    pub label: String,
    pub score: f64,
}
95
/// One completion from a text-generation model.
#[derive(Debug, Clone, Deserialize)]
pub struct GenerationResult {
    pub generated_text: String,
}
100
/// Error payload returned by the Inference API (e.g. model loading).
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceError {
    pub error: String,
}
105
/// Client for the Hugging Face Hub and Inference APIs.
///
/// When `HUGGINGFACE_API_KEY` is not set, `use_mock` is true and all
/// methods return local fixture data instead of performing HTTP calls.
pub struct HuggingFaceClient {
    client: Client,
    // Local fallback embedder used for mock inference and vectorization.
    embedder: SimpleEmbedder,
    base_url: String,
    api_key: Option<String>,
    use_mock: bool,
}
122
123impl HuggingFaceClient {
    /// Creates a client with the default embedding dimensionality
    /// (`DEFAULT_EMBEDDING_DIM`).
    pub fn new() -> Self {
        Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
    }
131
132 pub fn with_embedding_dim(embedding_dim: usize) -> Self {
134 let api_key = env::var("HUGGINGFACE_API_KEY").ok();
135 let use_mock = api_key.is_none();
136
137 if use_mock {
138 tracing::warn!("HUGGINGFACE_API_KEY not set, using mock data");
139 }
140
141 Self {
142 client: Client::builder()
143 .user_agent("RuVector-Discovery/1.0")
144 .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
145 .build()
146 .expect("Failed to create HTTP client"),
147 embedder: SimpleEmbedder::new(embedding_dim),
148 base_url: "https://huggingface.co/api".to_string(),
149 api_key,
150 use_mock,
151 }
152 }
153
154 pub async fn search_models(
165 &self,
166 query: &str,
167 task: Option<&str>,
168 ) -> Result<Vec<HuggingFaceModel>> {
169 if self.use_mock {
170 return Ok(self.mock_models(query));
171 }
172
173 sleep(Duration::from_millis(HUGGINGFACE_RATE_LIMIT_MS)).await;
174
175 let mut url = format!("{}/models?search={}", self.base_url, urlencoding::encode(query));
176 if let Some(task_filter) = task {
177 url.push_str(&format!("&filter={}", task_filter));
178 }
179 url.push_str("&limit=20");
180
181 let response = self.fetch_with_retry(&url).await?;
182 let models: Vec<HuggingFaceModel> = response.json().await?;
183
184 Ok(models)
185 }
186
187 pub async fn get_model(&self, model_id: &str) -> Result<Option<HuggingFaceModel>> {
192 if self.use_mock {
193 return Ok(self.mock_models(model_id).into_iter().next());
194 }
195
196 sleep(Duration::from_millis(HUGGINGFACE_RATE_LIMIT_MS)).await;
197
198 let url = format!("{}/models/{}", self.base_url, model_id);
199 let response = self.fetch_with_retry(&url).await?;
200 let model: HuggingFaceModel = response.json().await?;
201
202 Ok(Some(model))
203 }
204
205 pub async fn list_datasets(&self, query: Option<&str>) -> Result<Vec<HuggingFaceDataset>> {
210 if self.use_mock {
211 return Ok(self.mock_datasets(query.unwrap_or("ml")));
212 }
213
214 sleep(Duration::from_millis(HUGGINGFACE_RATE_LIMIT_MS)).await;
215
216 let mut url = format!("{}/datasets", self.base_url);
217 if let Some(q) = query {
218 url.push_str(&format!("?search={}", urlencoding::encode(q)));
219 }
220 url.push_str("&limit=20");
221
222 let response = self.fetch_with_retry(&url).await?;
223 let datasets: Vec<HuggingFaceDataset> = response.json().await?;
224
225 Ok(datasets)
226 }
227
228 pub async fn get_dataset(&self, dataset_id: &str) -> Result<Option<HuggingFaceDataset>> {
233 if self.use_mock {
234 return Ok(self.mock_datasets(dataset_id).into_iter().next());
235 }
236
237 sleep(Duration::from_millis(HUGGINGFACE_RATE_LIMIT_MS)).await;
238
239 let url = format!("{}/datasets/{}", self.base_url, dataset_id);
240 let response = self.fetch_with_retry(&url).await?;
241 let dataset: HuggingFaceDataset = response.json().await?;
242
243 Ok(Some(dataset))
244 }
245
246 pub async fn inference(
255 &self,
256 model_id: &str,
257 inputs: serde_json::Value,
258 ) -> Result<HuggingFaceInferenceResponse> {
259 if self.use_mock {
260 let embedding = self.embedder.embed_json(&inputs);
262 return Ok(HuggingFaceInferenceResponse::Embeddings(vec![embedding]));
263 }
264
265 sleep(Duration::from_millis(HUGGINGFACE_RATE_LIMIT_MS)).await;
266
267 let url = format!("https://api-inference.huggingface.co/models/{}", model_id);
268 let body = HuggingFaceInferenceInput { inputs };
269
270 let mut request = self.client.post(&url).json(&body);
271
272 if let Some(key) = &self.api_key {
273 request = request.header("Authorization", format!("Bearer {}", key));
274 }
275
276 let response = request.send().await?;
277
278 if !response.status().is_success() {
279 return Err(FrameworkError::Network(
280 reqwest::Error::from(response.error_for_status().unwrap_err()),
281 ));
282 }
283
284 let result: HuggingFaceInferenceResponse = response.json().await?;
285 Ok(result)
286 }
287
288 pub fn model_to_vector(&self, model: &HuggingFaceModel) -> SemanticVector {
290 let text = format!(
291 "{} {} {}",
292 model.model_id,
293 model.pipeline_tag.as_deref().unwrap_or(""),
294 model.tags.as_ref().map(|t| t.join(" ")).unwrap_or_default()
295 );
296
297 let embedding = self.embedder.embed_text(&text);
298
299 let mut metadata = HashMap::new();
300 metadata.insert("model_id".to_string(), model.model_id.clone());
301 if let Some(author) = &model.author {
302 metadata.insert("author".to_string(), author.clone());
303 }
304 if let Some(downloads) = model.downloads {
305 metadata.insert("downloads".to_string(), downloads.to_string());
306 }
307 if let Some(likes) = model.likes {
308 metadata.insert("likes".to_string(), likes.to_string());
309 }
310 if let Some(pipeline) = &model.pipeline_tag {
311 metadata.insert("task".to_string(), pipeline.clone());
312 }
313 metadata.insert("source".to_string(), "huggingface".to_string());
314
315 let timestamp = model
316 .created_at
317 .as_ref()
318 .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
319 .map(|dt| dt.with_timezone(&Utc))
320 .unwrap_or_else(Utc::now);
321
322 SemanticVector {
323 id: format!("hf:model:{}", model.model_id),
324 embedding,
325 domain: Domain::Research,
326 timestamp,
327 metadata,
328 }
329 }
330
331 pub fn dataset_to_vector(&self, dataset: &HuggingFaceDataset) -> SemanticVector {
333 let text = format!(
334 "{} {}",
335 dataset.id,
336 dataset.description.as_deref().unwrap_or("")
337 );
338
339 let embedding = self.embedder.embed_text(&text);
340
341 let mut metadata = HashMap::new();
342 metadata.insert("dataset_id".to_string(), dataset.id.clone());
343 if let Some(author) = &dataset.author {
344 metadata.insert("author".to_string(), author.clone());
345 }
346 if let Some(downloads) = dataset.downloads {
347 metadata.insert("downloads".to_string(), downloads.to_string());
348 }
349 metadata.insert("source".to_string(), "huggingface".to_string());
350
351 let timestamp = dataset
352 .created_at
353 .as_ref()
354 .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
355 .map(|dt| dt.with_timezone(&Utc))
356 .unwrap_or_else(Utc::now);
357
358 SemanticVector {
359 id: format!("hf:dataset:{}", dataset.id),
360 embedding,
361 domain: Domain::Research,
362 timestamp,
363 metadata,
364 }
365 }
366
367 fn mock_models(&self, query: &str) -> Vec<HuggingFaceModel> {
369 vec![
370 HuggingFaceModel {
371 model_id: format!("bert-base-{}", query),
372 author: Some("google".to_string()),
373 downloads: Some(1_000_000),
374 likes: Some(500),
375 tags: Some(vec!["nlp".to_string(), "transformer".to_string()]),
376 pipeline_tag: Some("fill-mask".to_string()),
377 created_at: Some(Utc::now().to_rfc3339()),
378 },
379 HuggingFaceModel {
380 model_id: format!("gpt2-{}", query),
381 author: Some("openai".to_string()),
382 downloads: Some(800_000),
383 likes: Some(350),
384 tags: Some(vec!["text-generation".to_string()]),
385 pipeline_tag: Some("text-generation".to_string()),
386 created_at: Some(Utc::now().to_rfc3339()),
387 },
388 ]
389 }
390
391 fn mock_datasets(&self, query: &str) -> Vec<HuggingFaceDataset> {
393 vec![
394 HuggingFaceDataset {
395 id: format!("squad-{}", query),
396 author: Some("datasets".to_string()),
397 downloads: Some(500_000),
398 likes: Some(200),
399 tags: Some(vec!["qa".to_string(), "english".to_string()]),
400 created_at: Some(Utc::now().to_rfc3339()),
401 description: Some("Question answering dataset".to_string()),
402 },
403 HuggingFaceDataset {
404 id: format!("glue-{}", query),
405 author: Some("datasets".to_string()),
406 downloads: Some(300_000),
407 likes: Some(150),
408 tags: Some(vec!["benchmark".to_string()]),
409 created_at: Some(Utc::now().to_rfc3339()),
410 description: Some("General Language Understanding Evaluation".to_string()),
411 },
412 ]
413 }
414
415 async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
416 let mut retries = 0;
417 loop {
418 let mut request = self.client.get(url);
419
420 if let Some(key) = &self.api_key {
421 request = request.header("Authorization", format!("Bearer {}", key));
422 }
423
424 match request.send().await {
425 Ok(response) => {
426 if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
427 retries += 1;
428 tracing::warn!("Rate limited, retrying in {}ms", RETRY_DELAY_MS * retries as u64);
429 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
430 continue;
431 }
432 if !response.status().is_success() {
433 return Err(FrameworkError::Network(
434 reqwest::Error::from(response.error_for_status().unwrap_err()),
435 ));
436 }
437 return Ok(response);
438 }
439 Err(_) if retries < MAX_RETRIES => {
440 retries += 1;
441 tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
442 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
443 }
444 Err(e) => return Err(FrameworkError::Network(e)),
445 }
446 }
447 }
448}
449
// Delegates to `HuggingFaceClient::new` (default embedding dimension).
impl Default for HuggingFaceClient {
    fn default() -> Self {
        Self::new()
    }
}
455
/// One locally installed model as reported by Ollama's `/tags` endpoint.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OllamaModel {
    pub name: String,
    /// RFC 3339 timestamp string of the last modification, if reported.
    pub modified_at: Option<String>,
    /// On-disk size in bytes, if reported.
    pub size: Option<u64>,
    pub digest: Option<String>,
}
468
/// Envelope for the `/tags` model listing.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaModelsResponse {
    pub models: Vec<OllamaModel>,
}
474
/// Request body for Ollama's `/generate` endpoint.
#[derive(Debug, Clone, Serialize)]
pub struct OllamaGenerateRequest {
    pub model: String,
    pub prompt: String,
    // Always sent as `false` here: callers expect one complete response.
    pub stream: bool,
}
482
/// Non-streaming response from Ollama's `/generate` endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaGenerateResponse {
    pub model: String,
    pub response: String,
    pub done: bool,
}
490
/// One outgoing chat message (`role` is e.g. "user"/"assistant"/"system").
#[derive(Debug, Clone, Serialize)]
pub struct OllamaChatMessage {
    pub role: String,
    pub content: String,
}
497
/// Request body for Ollama's `/chat` endpoint.
#[derive(Debug, Clone, Serialize)]
pub struct OllamaChatRequest {
    pub model: String,
    pub messages: Vec<OllamaChatMessage>,
    // Always sent as `false` here: callers expect one complete response.
    pub stream: bool,
}
505
/// Non-streaming response from Ollama's `/chat` endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaChatResponse {
    pub model: String,
    pub message: OllamaMessage,
    pub done: bool,
}
513
/// One incoming chat message in an Ollama chat response.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaMessage {
    pub role: String,
    pub content: String,
}
519
/// Request body for Ollama's `/embeddings` endpoint.
#[derive(Debug, Clone, Serialize)]
pub struct OllamaEmbeddingsRequest {
    pub model: String,
    pub prompt: String,
}
526
/// Response from Ollama's `/embeddings` endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaEmbeddingsResponse {
    pub embedding: Vec<f32>,
}
532
/// Client for a local Ollama daemon.
///
/// Starts in live mode; `use_mock` is latched to true the first time the
/// daemon turns out to be unreachable, after which fixture data and the
/// local embedder are used instead.
pub struct OllamaClient {
    client: Client,
    // Local fallback embedder used when Ollama is unreachable.
    embedder: SimpleEmbedder,
    base_url: String,
    use_mock: bool,
}
546
547impl OllamaClient {
    /// Creates a client pointed at the default local daemon address.
    pub fn new() -> Self {
        Self::with_base_url("http://localhost:11434/api")
    }
552
553 pub fn with_base_url(base_url: &str) -> Self {
558 Self {
559 client: Client::builder()
560 .user_agent("RuVector-Discovery/1.0")
561 .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
562 .build()
563 .expect("Failed to create HTTP client"),
564 embedder: SimpleEmbedder::new(DEFAULT_EMBEDDING_DIM),
565 base_url: base_url.to_string(),
566 use_mock: false,
567 }
568 }
569
570 pub async fn is_available(&self) -> bool {
572 self.client
573 .get(&format!("{}/tags", self.base_url))
574 .send()
575 .await
576 .map(|r| r.status().is_success())
577 .unwrap_or(false)
578 }
579
580 pub async fn list_models(&mut self) -> Result<Vec<OllamaModel>> {
582 sleep(Duration::from_millis(OLLAMA_RATE_LIMIT_MS)).await;
583
584 let url = format!("{}/tags", self.base_url);
585
586 match self.client.get(&url).send().await {
587 Ok(response) if response.status().is_success() => {
588 let data: OllamaModelsResponse = response.json().await?;
589 self.use_mock = false;
590 Ok(data.models)
591 }
592 _ => {
593 if !self.use_mock {
594 tracing::warn!("Ollama not available, using mock data");
595 self.use_mock = true;
596 }
597 Ok(self.mock_models())
598 }
599 }
600 }
601
602 pub async fn generate(&mut self, model: &str, prompt: &str) -> Result<String> {
608 if self.use_mock || !self.is_available().await {
609 self.use_mock = true;
610 return Ok(self.mock_generation(prompt));
611 }
612
613 sleep(Duration::from_millis(OLLAMA_RATE_LIMIT_MS)).await;
614
615 let url = format!("{}/generate", self.base_url);
616 let body = OllamaGenerateRequest {
617 model: model.to_string(),
618 prompt: prompt.to_string(),
619 stream: false,
620 };
621
622 let response = self.client.post(&url).json(&body).send().await?;
623
624 if !response.status().is_success() {
625 return Err(FrameworkError::Network(
626 reqwest::Error::from(response.error_for_status().unwrap_err()),
627 ));
628 }
629
630 let result: OllamaGenerateResponse = response.json().await?;
631 Ok(result.response)
632 }
633
634 pub async fn chat(
640 &mut self,
641 model: &str,
642 messages: Vec<OllamaChatMessage>,
643 ) -> Result<String> {
644 if self.use_mock || !self.is_available().await {
645 self.use_mock = true;
646 let last_msg = messages.last().map(|m| m.content.as_str()).unwrap_or("");
647 return Ok(self.mock_generation(last_msg));
648 }
649
650 sleep(Duration::from_millis(OLLAMA_RATE_LIMIT_MS)).await;
651
652 let url = format!("{}/chat", self.base_url);
653 let body = OllamaChatRequest {
654 model: model.to_string(),
655 messages,
656 stream: false,
657 };
658
659 let response = self.client.post(&url).json(&body).send().await?;
660
661 if !response.status().is_success() {
662 return Err(FrameworkError::Network(
663 reqwest::Error::from(response.error_for_status().unwrap_err()),
664 ));
665 }
666
667 let result: OllamaChatResponse = response.json().await?;
668 Ok(result.message.content)
669 }
670
671 pub async fn embeddings(&mut self, model: &str, prompt: &str) -> Result<Vec<f32>> {
677 if self.use_mock || !self.is_available().await {
678 self.use_mock = true;
679 return Ok(self.embedder.embed_text(prompt));
680 }
681
682 sleep(Duration::from_millis(OLLAMA_RATE_LIMIT_MS)).await;
683
684 let url = format!("{}/embeddings", self.base_url);
685 let body = OllamaEmbeddingsRequest {
686 model: model.to_string(),
687 prompt: prompt.to_string(),
688 };
689
690 let response = self.client.post(&url).json(&body).send().await?;
691
692 if !response.status().is_success() {
693 return Err(FrameworkError::Network(
694 reqwest::Error::from(response.error_for_status().unwrap_err()),
695 ));
696 }
697
698 let result: OllamaEmbeddingsResponse = response.json().await?;
699 Ok(result.embedding)
700 }
701
702 pub async fn pull_model(&mut self, name: &str) -> Result<bool> {
710 if self.use_mock || !self.is_available().await {
711 self.use_mock = true;
712 tracing::warn!("Ollama not available, cannot pull model");
713 return Ok(false);
714 }
715
716 sleep(Duration::from_millis(OLLAMA_RATE_LIMIT_MS)).await;
717
718 let url = format!("{}/pull", self.base_url);
719 let body = serde_json::json!({ "name": name });
720
721 let response = self.client.post(&url).json(&body).send().await?;
722 Ok(response.status().is_success())
723 }
724
725 pub fn model_to_vector(&self, model: &OllamaModel) -> SemanticVector {
727 let embedding = self.embedder.embed_text(&model.name);
728
729 let mut metadata = HashMap::new();
730 metadata.insert("model_name".to_string(), model.name.clone());
731 if let Some(size) = model.size {
732 metadata.insert("size_bytes".to_string(), size.to_string());
733 }
734 if let Some(digest) = &model.digest {
735 metadata.insert("digest".to_string(), digest.clone());
736 }
737 metadata.insert("source".to_string(), "ollama".to_string());
738
739 let timestamp = model
740 .modified_at
741 .as_ref()
742 .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
743 .map(|dt| dt.with_timezone(&Utc))
744 .unwrap_or_else(Utc::now);
745
746 SemanticVector {
747 id: format!("ollama:model:{}", model.name),
748 embedding,
749 domain: Domain::Research,
750 timestamp,
751 metadata,
752 }
753 }
754
755 fn mock_models(&self) -> Vec<OllamaModel> {
756 vec![
757 OllamaModel {
758 name: "llama2:latest".to_string(),
759 modified_at: Some(Utc::now().to_rfc3339()),
760 size: Some(3_800_000_000),
761 digest: Some("sha256:mock123".to_string()),
762 },
763 OllamaModel {
764 name: "mistral:latest".to_string(),
765 modified_at: Some(Utc::now().to_rfc3339()),
766 size: Some(4_100_000_000),
767 digest: Some("sha256:mock456".to_string()),
768 },
769 ]
770 }
771
772 fn mock_generation(&self, prompt: &str) -> String {
773 format!("Mock response to: {}", prompt.chars().take(50).collect::<String>())
774 }
775}
776
// Delegates to `OllamaClient::new` (default local daemon address).
impl Default for OllamaClient {
    fn default() -> Self {
        Self::new()
    }
}
782
/// Metadata for a model on Replicate (identified by `owner/name`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReplicateModel {
    pub owner: String,
    pub name: String,
    pub description: Option<String>,
    pub visibility: Option<String>,
    pub github_url: Option<String>,
    pub paper_url: Option<String>,
    /// Latest published version; needed to create predictions.
    pub latest_version: Option<ReplicateVersion>,
}
798
/// One published version of a Replicate model.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReplicateVersion {
    pub id: String,
    /// RFC 3339 timestamp string, if reported.
    pub created_at: Option<String>,
}
805
/// Request body for Replicate's `/predictions` endpoint.
#[derive(Debug, Clone, Serialize)]
pub struct ReplicatePredictionRequest {
    /// Model version id to run.
    pub version: String,
    /// Model-specific JSON input payload.
    pub input: serde_json::Value,
}
812
/// A prediction record as returned by Replicate (status may be pending,
/// in which case `output` is absent).
#[derive(Debug, Clone, Deserialize)]
pub struct ReplicatePrediction {
    pub id: String,
    pub status: String,
    pub output: Option<serde_json::Value>,
    pub error: Option<String>,
}
821
/// A curated model collection on Replicate.
#[derive(Debug, Clone, Deserialize)]
pub struct ReplicateCollection {
    pub name: String,
    pub slug: String,
    pub description: Option<String>,
}
829
/// Client for the Replicate REST API.
///
/// When `REPLICATE_API_TOKEN` is not set, `use_mock` is true and all
/// methods return local fixture data instead of performing HTTP calls.
pub struct ReplicateClient {
    client: Client,
    // Local fallback embedder used for vectorization.
    embedder: SimpleEmbedder,
    base_url: String,
    api_token: Option<String>,
    use_mock: bool,
}
846
847impl ReplicateClient {
    /// Creates a client with the default embedding dimensionality
    /// (`DEFAULT_EMBEDDING_DIM`).
    pub fn new() -> Self {
        Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
    }
854
855 pub fn with_embedding_dim(embedding_dim: usize) -> Self {
857 let api_token = env::var("REPLICATE_API_TOKEN").ok();
858 let use_mock = api_token.is_none();
859
860 if use_mock {
861 tracing::warn!("REPLICATE_API_TOKEN not set, using mock data");
862 }
863
864 Self {
865 client: Client::builder()
866 .user_agent("RuVector-Discovery/1.0")
867 .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
868 .build()
869 .expect("Failed to create HTTP client"),
870 embedder: SimpleEmbedder::new(embedding_dim),
871 base_url: "https://api.replicate.com/v1".to_string(),
872 api_token,
873 use_mock,
874 }
875 }
876
877 pub async fn get_model(&self, owner: &str, name: &str) -> Result<Option<ReplicateModel>> {
883 if self.use_mock {
884 return Ok(Some(self.mock_model(owner, name)));
885 }
886
887 sleep(Duration::from_millis(REPLICATE_RATE_LIMIT_MS)).await;
888
889 let url = format!("{}/models/{}/{}", self.base_url, owner, name);
890 let response = self.fetch_with_retry(&url).await?;
891 let model: ReplicateModel = response.json().await?;
892
893 Ok(Some(model))
894 }
895
896 pub async fn create_prediction(
902 &self,
903 model: &str,
904 input: serde_json::Value,
905 ) -> Result<ReplicatePrediction> {
906 if self.use_mock {
907 return Ok(self.mock_prediction());
908 }
909
910 sleep(Duration::from_millis(REPLICATE_RATE_LIMIT_MS)).await;
911
912 let url = format!("{}/predictions", self.base_url);
913
914 let parts: Vec<&str> = model.split('/').collect();
916 if parts.len() != 2 {
917 return Err(FrameworkError::Config(
918 "Model must be in 'owner/name' format".to_string(),
919 ));
920 }
921
922 let model_info = self.get_model(parts[0], parts[1]).await?;
923 let version = model_info
924 .and_then(|m| m.latest_version)
925 .and_then(|v| Some(v.id))
926 .ok_or_else(|| FrameworkError::Config("Model version not found".to_string()))?;
927
928 let body = ReplicatePredictionRequest { version, input };
929
930 let response = self.fetch_with_retry_post(&url, &body).await?;
931 let prediction: ReplicatePrediction = response.json().await?;
932
933 Ok(prediction)
934 }
935
936 pub async fn get_prediction(&self, id: &str) -> Result<ReplicatePrediction> {
941 if self.use_mock {
942 return Ok(self.mock_prediction());
943 }
944
945 sleep(Duration::from_millis(REPLICATE_RATE_LIMIT_MS)).await;
946
947 let url = format!("{}/predictions/{}", self.base_url, id);
948 let response = self.fetch_with_retry(&url).await?;
949 let prediction: ReplicatePrediction = response.json().await?;
950
951 Ok(prediction)
952 }
953
954 pub async fn list_collections(&self) -> Result<Vec<ReplicateCollection>> {
956 if self.use_mock {
957 return Ok(self.mock_collections());
958 }
959
960 sleep(Duration::from_millis(REPLICATE_RATE_LIMIT_MS)).await;
961
962 let url = format!("{}/collections", self.base_url);
963 let response = self.fetch_with_retry(&url).await?;
964 let collections: Vec<ReplicateCollection> = response.json().await?;
965
966 Ok(collections)
967 }
968
969 pub fn model_to_vector(&self, model: &ReplicateModel) -> SemanticVector {
971 let text = format!(
972 "{}/{} {}",
973 model.owner,
974 model.name,
975 model.description.as_deref().unwrap_or("")
976 );
977
978 let embedding = self.embedder.embed_text(&text);
979
980 let mut metadata = HashMap::new();
981 metadata.insert("owner".to_string(), model.owner.clone());
982 metadata.insert("name".to_string(), model.name.clone());
983 if let Some(desc) = &model.description {
984 metadata.insert("description".to_string(), desc.clone());
985 }
986 if let Some(github) = &model.github_url {
987 metadata.insert("github_url".to_string(), github.clone());
988 }
989 if let Some(paper) = &model.paper_url {
990 metadata.insert("paper_url".to_string(), paper.clone());
991 }
992 metadata.insert("source".to_string(), "replicate".to_string());
993
994 let timestamp = model
995 .latest_version
996 .as_ref()
997 .and_then(|v| v.created_at.as_ref())
998 .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
999 .map(|dt| dt.with_timezone(&Utc))
1000 .unwrap_or_else(Utc::now);
1001
1002 SemanticVector {
1003 id: format!("replicate:{}/{}", model.owner, model.name),
1004 embedding,
1005 domain: Domain::Research,
1006 timestamp,
1007 metadata,
1008 }
1009 }
1010
1011 fn mock_model(&self, owner: &str, name: &str) -> ReplicateModel {
1012 ReplicateModel {
1013 owner: owner.to_string(),
1014 name: name.to_string(),
1015 description: Some("Mock model for testing".to_string()),
1016 visibility: Some("public".to_string()),
1017 github_url: None,
1018 paper_url: None,
1019 latest_version: Some(ReplicateVersion {
1020 id: "mock-version-123".to_string(),
1021 created_at: Some(Utc::now().to_rfc3339()),
1022 }),
1023 }
1024 }
1025
1026 fn mock_prediction(&self) -> ReplicatePrediction {
1027 ReplicatePrediction {
1028 id: "mock-prediction-123".to_string(),
1029 status: "succeeded".to_string(),
1030 output: Some(serde_json::json!({"result": "mock output"})),
1031 error: None,
1032 }
1033 }
1034
1035 fn mock_collections(&self) -> Vec<ReplicateCollection> {
1036 vec![
1037 ReplicateCollection {
1038 name: "Text to Image".to_string(),
1039 slug: "text-to-image".to_string(),
1040 description: Some("Generate images from text".to_string()),
1041 },
1042 ReplicateCollection {
1043 name: "Image to Text".to_string(),
1044 slug: "image-to-text".to_string(),
1045 description: Some("Generate text from images".to_string()),
1046 },
1047 ]
1048 }
1049
1050 async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
1051 let mut retries = 0;
1052 loop {
1053 let mut request = self.client.get(url);
1054
1055 if let Some(token) = &self.api_token {
1056 request = request.header("Authorization", format!("Token {}", token));
1057 }
1058
1059 match request.send().await {
1060 Ok(response) => {
1061 if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
1062 retries += 1;
1063 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1064 continue;
1065 }
1066 if !response.status().is_success() {
1067 return Err(FrameworkError::Network(
1068 reqwest::Error::from(response.error_for_status().unwrap_err()),
1069 ));
1070 }
1071 return Ok(response);
1072 }
1073 Err(_) if retries < MAX_RETRIES => {
1074 retries += 1;
1075 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1076 }
1077 Err(e) => return Err(FrameworkError::Network(e)),
1078 }
1079 }
1080 }
1081
1082 async fn fetch_with_retry_post<T: Serialize>(
1083 &self,
1084 url: &str,
1085 body: &T,
1086 ) -> Result<reqwest::Response> {
1087 let mut retries = 0;
1088 loop {
1089 let mut request = self.client.post(url).json(body);
1090
1091 if let Some(token) = &self.api_token {
1092 request = request.header("Authorization", format!("Token {}", token));
1093 }
1094
1095 match request.send().await {
1096 Ok(response) => {
1097 if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
1098 retries += 1;
1099 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1100 continue;
1101 }
1102 if !response.status().is_success() {
1103 return Err(FrameworkError::Network(
1104 reqwest::Error::from(response.error_for_status().unwrap_err()),
1105 ));
1106 }
1107 return Ok(response);
1108 }
1109 Err(_) if retries < MAX_RETRIES => {
1110 retries += 1;
1111 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1112 }
1113 Err(e) => return Err(FrameworkError::Network(e)),
1114 }
1115 }
1116 }
1117}
1118
// Delegates to `ReplicateClient::new` (default embedding dimension).
impl Default for ReplicateClient {
    fn default() -> Self {
        Self::new()
    }
}
1124
1125#[derive(Debug, Clone, Deserialize, Serialize)]
1131pub struct TogetherModel {
1132 pub id: String,
1133 pub name: Option<String>,
1134 #[serde(rename = "display_name")]
1135 pub display_name: Option<String>,
1136 pub description: Option<String>,
1137 pub context_length: Option<u64>,
1138 pub pricing: Option<TogetherPricing>,
1139}
1140
/// Per-token pricing for a Together model (input/output rates).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TogetherPricing {
    pub input: Option<f64>,
    pub output: Option<f64>,
}
1146
/// Request body for Together's `/chat/completions` endpoint.
#[derive(Debug, Clone, Serialize)]
pub struct TogetherChatRequest {
    pub model: String,
    pub messages: Vec<TogetherMessage>,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
}
1155
/// One chat message (`role` is e.g. "user"/"assistant"/"system");
/// used in both requests and responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TogetherMessage {
    pub role: String,
    pub content: String,
}
1161
/// Response from Together's `/chat/completions` endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct TogetherChatResponse {
    pub id: String,
    pub choices: Vec<TogetherChoice>,
    pub usage: Option<TogetherUsage>,
}
1169
/// One generated alternative in a chat completion response.
#[derive(Debug, Clone, Deserialize)]
pub struct TogetherChoice {
    pub message: TogetherMessage,
    pub finish_reason: Option<String>,
}
1175
/// Token accounting reported by the API.
#[derive(Debug, Clone, Deserialize)]
pub struct TogetherUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
1182
/// Request body for Together's `/embeddings` endpoint.
#[derive(Debug, Clone, Serialize)]
pub struct TogetherEmbeddingsRequest {
    pub model: String,
    pub input: String,
}
1189
/// Response from Together's `/embeddings` endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct TogetherEmbeddingsResponse {
    pub data: Vec<TogetherEmbeddingData>,
}
1195
/// One embedding entry (with its position in the request batch).
#[derive(Debug, Clone, Deserialize)]
pub struct TogetherEmbeddingData {
    pub embedding: Vec<f32>,
    pub index: u32,
}
1201
/// Client for the Together AI REST API.
///
/// When `TOGETHER_API_KEY` is not set, `use_mock` is true and all methods
/// return local fixture data instead of performing HTTP calls.
pub struct TogetherAiClient {
    client: Client,
    // Local fallback embedder used for mock embeddings and vectorization.
    embedder: SimpleEmbedder,
    base_url: String,
    api_key: Option<String>,
    use_mock: bool,
}
1218
1219impl TogetherAiClient {
    /// Creates a client with the default embedding dimensionality
    /// (`DEFAULT_EMBEDDING_DIM`).
    pub fn new() -> Self {
        Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
    }
1226
1227 pub fn with_embedding_dim(embedding_dim: usize) -> Self {
1229 let api_key = env::var("TOGETHER_API_KEY").ok();
1230 let use_mock = api_key.is_none();
1231
1232 if use_mock {
1233 tracing::warn!("TOGETHER_API_KEY not set, using mock data");
1234 }
1235
1236 Self {
1237 client: Client::builder()
1238 .user_agent("RuVector-Discovery/1.0")
1239 .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
1240 .build()
1241 .expect("Failed to create HTTP client"),
1242 embedder: SimpleEmbedder::new(embedding_dim),
1243 base_url: "https://api.together.xyz/v1".to_string(),
1244 api_key,
1245 use_mock,
1246 }
1247 }
1248
1249 pub async fn list_models(&self) -> Result<Vec<TogetherModel>> {
1251 if self.use_mock {
1252 return Ok(self.mock_models());
1253 }
1254
1255 sleep(Duration::from_millis(TOGETHER_RATE_LIMIT_MS)).await;
1256
1257 let url = format!("{}/models", self.base_url);
1258 let response = self.fetch_with_retry(&url).await?;
1259 let models: Vec<TogetherModel> = response.json().await?;
1260
1261 Ok(models)
1262 }
1263
1264 pub async fn chat_completion(
1270 &self,
1271 model: &str,
1272 messages: Vec<TogetherMessage>,
1273 ) -> Result<String> {
1274 if self.use_mock {
1275 let last_msg = messages.last().map(|m| m.content.as_str()).unwrap_or("");
1276 return Ok(format!("Mock response to: {}", last_msg));
1277 }
1278
1279 sleep(Duration::from_millis(TOGETHER_RATE_LIMIT_MS)).await;
1280
1281 let url = format!("{}/chat/completions", self.base_url);
1282 let body = TogetherChatRequest {
1283 model: model.to_string(),
1284 messages,
1285 max_tokens: Some(512),
1286 temperature: Some(0.7),
1287 };
1288
1289 let response = self.fetch_with_retry_post(&url, &body).await?;
1290 let result: TogetherChatResponse = response.json().await?;
1291
1292 Ok(result
1293 .choices
1294 .first()
1295 .map(|c| c.message.content.clone())
1296 .unwrap_or_default())
1297 }
1298
1299 pub async fn embeddings(&self, model: &str, input: &str) -> Result<Vec<f32>> {
1305 if self.use_mock {
1306 return Ok(self.embedder.embed_text(input));
1307 }
1308
1309 sleep(Duration::from_millis(TOGETHER_RATE_LIMIT_MS)).await;
1310
1311 let url = format!("{}/embeddings", self.base_url);
1312 let body = TogetherEmbeddingsRequest {
1313 model: model.to_string(),
1314 input: input.to_string(),
1315 };
1316
1317 let response = self.fetch_with_retry_post(&url, &body).await?;
1318 let result: TogetherEmbeddingsResponse = response.json().await?;
1319
1320 Ok(result
1321 .data
1322 .first()
1323 .map(|d| d.embedding.clone())
1324 .unwrap_or_default())
1325 }
1326
1327 pub fn model_to_vector(&self, model: &TogetherModel) -> SemanticVector {
1329 let text = format!(
1330 "{} {}",
1331 model.display_name.as_deref().unwrap_or(&model.id),
1332 model.description.as_deref().unwrap_or("")
1333 );
1334
1335 let embedding = self.embedder.embed_text(&text);
1336
1337 let mut metadata = HashMap::new();
1338 metadata.insert("model_id".to_string(), model.id.clone());
1339 if let Some(name) = &model.display_name {
1340 metadata.insert("display_name".to_string(), name.clone());
1341 }
1342 if let Some(ctx) = model.context_length {
1343 metadata.insert("context_length".to_string(), ctx.to_string());
1344 }
1345 metadata.insert("source".to_string(), "together".to_string());
1346
1347 SemanticVector {
1348 id: format!("together:{}", model.id),
1349 embedding,
1350 domain: Domain::Research,
1351 timestamp: Utc::now(),
1352 metadata,
1353 }
1354 }
1355
1356 fn mock_models(&self) -> Vec<TogetherModel> {
1357 vec![
1358 TogetherModel {
1359 id: "togethercomputer/llama-2-7b".to_string(),
1360 name: Some("Llama 2 7B".to_string()),
1361 display_name: Some("Llama 2 7B".to_string()),
1362 description: Some("Meta's Llama 2 7B model".to_string()),
1363 context_length: Some(4096),
1364 pricing: None,
1365 },
1366 TogetherModel {
1367 id: "mistralai/Mistral-7B-v0.1".to_string(),
1368 name: Some("Mistral 7B".to_string()),
1369 display_name: Some("Mistral 7B".to_string()),
1370 description: Some("Mistral AI's 7B model".to_string()),
1371 context_length: Some(8192),
1372 pricing: None,
1373 },
1374 ]
1375 }
1376
1377 async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
1378 let mut retries = 0;
1379 loop {
1380 let mut request = self.client.get(url);
1381
1382 if let Some(key) = &self.api_key {
1383 request = request.header("Authorization", format!("Bearer {}", key));
1384 }
1385
1386 match request.send().await {
1387 Ok(response) => {
1388 if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
1389 retries += 1;
1390 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1391 continue;
1392 }
1393 if !response.status().is_success() {
1394 return Err(FrameworkError::Network(
1395 reqwest::Error::from(response.error_for_status().unwrap_err()),
1396 ));
1397 }
1398 return Ok(response);
1399 }
1400 Err(_) if retries < MAX_RETRIES => {
1401 retries += 1;
1402 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1403 }
1404 Err(e) => return Err(FrameworkError::Network(e)),
1405 }
1406 }
1407 }
1408
1409 async fn fetch_with_retry_post<T: Serialize>(
1410 &self,
1411 url: &str,
1412 body: &T,
1413 ) -> Result<reqwest::Response> {
1414 let mut retries = 0;
1415 loop {
1416 let mut request = self.client.post(url).json(body);
1417
1418 if let Some(key) = &self.api_key {
1419 request = request.header("Authorization", format!("Bearer {}", key));
1420 }
1421
1422 match request.send().await {
1423 Ok(response) => {
1424 if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
1425 retries += 1;
1426 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1427 continue;
1428 }
1429 if !response.status().is_success() {
1430 return Err(FrameworkError::Network(
1431 reqwest::Error::from(response.error_for_status().unwrap_err()),
1432 ));
1433 }
1434 return Ok(response);
1435 }
1436 Err(_) if retries < MAX_RETRIES => {
1437 retries += 1;
1438 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1439 }
1440 Err(e) => return Err(FrameworkError::Network(e)),
1441 }
1442 }
1443 }
1444}
1445
1446impl Default for TogetherAiClient {
1447 fn default() -> Self {
1448 Self::new()
1449 }
1450}
1451
1452#[derive(Debug, Clone, Deserialize, Serialize)]
1458pub struct PaperWithCodePaper {
1459 pub id: String,
1460 pub title: String,
1461 pub abstract_text: Option<String>,
1462 pub url_abs: Option<String>,
1463 pub url_pdf: Option<String>,
1464 pub published: Option<String>,
1465 pub authors: Option<Vec<PaperAuthor>>,
1466}
1467
/// A single author entry attached to a Papers With Code paper.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PaperAuthor {
    pub name: String,
}
1472
/// A dataset record as returned by the Papers With Code `/datasets` endpoint.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PaperWithCodeDataset {
    pub id: String,
    pub name: String,
    pub full_name: Option<String>,
    pub description: Option<String>,
    pub url: Option<String>,
    // Reference to an associated paper — presumably a URL or id; format not
    // visible from this file, confirm against the API response.
    pub paper: Option<String>,
}
1483
/// A single state-of-the-art leaderboard entry (task/dataset/metric triple
/// plus the best reported value and its source paper).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SotaEntry {
    pub task: String,
    pub dataset: String,
    pub metric: String,
    pub value: f64,
    pub paper_title: Option<String>,
    pub paper_url: Option<String>,
}
1494
/// A machine-learning method/technique record (e.g. an architecture) with an
/// optional link to the paper that introduced it.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Method {
    pub name: String,
    pub full_name: Option<String>,
    pub description: Option<String>,
    pub paper: Option<String>,
}
1503
/// Envelope for paper search results.
#[derive(Debug, Clone, Deserialize)]
pub struct PapersSearchResponse {
    pub results: Vec<PaperWithCodePaper>,
    // Total match count — presumably for pagination; only `results` is
    // consumed in this file.
    pub count: Option<u32>,
}
1510
/// Envelope for dataset listing results.
#[derive(Debug, Clone, Deserialize)]
pub struct DatasetsResponse {
    pub results: Vec<PaperWithCodeDataset>,
    // Total match count — presumably for pagination; only `results` is
    // consumed in this file.
    pub count: Option<u32>,
}
1517
/// Client for the public Papers With Code REST API.
pub struct PapersWithCodeClient {
    // Shared reqwest client (connection pooling, timeout, user agent).
    client: Client,
    // Local embedder used to vectorize paper/dataset text.
    embedder: SimpleEmbedder,
    // API root, e.g. "https://paperswithcode.com/api/v1".
    base_url: String,
}
1529
1530impl PapersWithCodeClient {
1531 pub fn new() -> Self {
1533 Self::with_embedding_dim(DEFAULT_EMBEDDING_DIM)
1534 }
1535
1536 pub fn with_embedding_dim(embedding_dim: usize) -> Self {
1538 Self {
1539 client: Client::builder()
1540 .user_agent("RuVector-Discovery/1.0")
1541 .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
1542 .build()
1543 .expect("Failed to create HTTP client"),
1544 embedder: SimpleEmbedder::new(embedding_dim),
1545 base_url: "https://paperswithcode.com/api/v1".to_string(),
1546 }
1547 }
1548
1549 pub async fn search_papers(&self, query: &str) -> Result<Vec<PaperWithCodePaper>> {
1554 sleep(Duration::from_millis(PAPERWITHCODE_RATE_LIMIT_MS)).await;
1555
1556 let url = format!(
1557 "{}/papers/?q={}",
1558 self.base_url,
1559 urlencoding::encode(query)
1560 );
1561
1562 let response = self.fetch_with_retry(&url).await?;
1563 let data: PapersSearchResponse = response.json().await?;
1564
1565 Ok(data.results)
1566 }
1567
1568 pub async fn get_paper(&self, paper_id: &str) -> Result<Option<PaperWithCodePaper>> {
1573 sleep(Duration::from_millis(PAPERWITHCODE_RATE_LIMIT_MS)).await;
1574
1575 let url = format!("{}/papers/{}/", self.base_url, paper_id);
1576 let response = self.fetch_with_retry(&url).await?;
1577 let paper: PaperWithCodePaper = response.json().await?;
1578
1579 Ok(Some(paper))
1580 }
1581
1582 pub async fn list_datasets(&self) -> Result<Vec<PaperWithCodeDataset>> {
1584 sleep(Duration::from_millis(PAPERWITHCODE_RATE_LIMIT_MS)).await;
1585
1586 let url = format!("{}/datasets/", self.base_url);
1587 let response = self.fetch_with_retry(&url).await?;
1588 let data: DatasetsResponse = response.json().await?;
1589
1590 Ok(data.results)
1591 }
1592
1593 pub async fn get_sota(&self, task: &str) -> Result<Vec<SotaEntry>> {
1598 sleep(Duration::from_millis(PAPERWITHCODE_RATE_LIMIT_MS)).await;
1599
1600 let url = format!("{}/sota/?task={}", self.base_url, urlencoding::encode(task));
1601
1602 Ok(self.mock_sota(task))
1605 }
1606
1607 pub async fn search_methods(&self, query: &str) -> Result<Vec<Method>> {
1612 sleep(Duration::from_millis(PAPERWITHCODE_RATE_LIMIT_MS)).await;
1613
1614 Ok(self.mock_methods(query))
1616 }
1617
1618 pub fn paper_to_vector(&self, paper: &PaperWithCodePaper) -> SemanticVector {
1620 let text = format!(
1621 "{} {}",
1622 paper.title,
1623 paper.abstract_text.as_deref().unwrap_or("")
1624 );
1625
1626 let embedding = self.embedder.embed_text(&text);
1627
1628 let mut metadata = HashMap::new();
1629 metadata.insert("paper_id".to_string(), paper.id.clone());
1630 metadata.insert("title".to_string(), paper.title.clone());
1631 if let Some(url) = &paper.url_abs {
1632 metadata.insert("url".to_string(), url.clone());
1633 }
1634 if let Some(pdf) = &paper.url_pdf {
1635 metadata.insert("pdf_url".to_string(), pdf.clone());
1636 }
1637 if let Some(authors) = &paper.authors {
1638 let author_names = authors
1639 .iter()
1640 .map(|a| a.name.as_str())
1641 .collect::<Vec<_>>()
1642 .join(", ");
1643 metadata.insert("authors".to_string(), author_names);
1644 }
1645 metadata.insert("source".to_string(), "paperswithcode".to_string());
1646
1647 let timestamp = paper
1648 .published
1649 .as_ref()
1650 .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
1651 .map(|dt| dt.with_timezone(&Utc))
1652 .unwrap_or_else(Utc::now);
1653
1654 SemanticVector {
1655 id: format!("pwc:paper:{}", paper.id),
1656 embedding,
1657 domain: Domain::Research,
1658 timestamp,
1659 metadata,
1660 }
1661 }
1662
1663 pub fn dataset_to_vector(&self, dataset: &PaperWithCodeDataset) -> SemanticVector {
1665 let text = format!(
1666 "{} {}",
1667 dataset.name,
1668 dataset.description.as_deref().unwrap_or("")
1669 );
1670
1671 let embedding = self.embedder.embed_text(&text);
1672
1673 let mut metadata = HashMap::new();
1674 metadata.insert("dataset_id".to_string(), dataset.id.clone());
1675 metadata.insert("name".to_string(), dataset.name.clone());
1676 if let Some(desc) = &dataset.description {
1677 metadata.insert("description".to_string(), desc.clone());
1678 }
1679 if let Some(url) = &dataset.url {
1680 metadata.insert("url".to_string(), url.clone());
1681 }
1682 metadata.insert("source".to_string(), "paperswithcode".to_string());
1683
1684 SemanticVector {
1685 id: format!("pwc:dataset:{}", dataset.id),
1686 embedding,
1687 domain: Domain::Research,
1688 timestamp: Utc::now(),
1689 metadata,
1690 }
1691 }
1692
1693 fn mock_sota(&self, task: &str) -> Vec<SotaEntry> {
1694 vec![
1695 SotaEntry {
1696 task: task.to_string(),
1697 dataset: "ImageNet".to_string(),
1698 metric: "Top-1 Accuracy".to_string(),
1699 value: 90.2,
1700 paper_title: Some("Vision Transformer".to_string()),
1701 paper_url: Some("https://arxiv.org/abs/2010.11929".to_string()),
1702 },
1703 SotaEntry {
1704 task: task.to_string(),
1705 dataset: "COCO".to_string(),
1706 metric: "mAP".to_string(),
1707 value: 58.7,
1708 paper_title: Some("DETR".to_string()),
1709 paper_url: Some("https://arxiv.org/abs/2005.12872".to_string()),
1710 },
1711 ]
1712 }
1713
1714 fn mock_methods(&self, query: &str) -> Vec<Method> {
1715 vec![
1716 Method {
1717 name: format!("Transformer-{}", query),
1718 full_name: Some("Transformer Architecture".to_string()),
1719 description: Some("Attention-based neural network architecture".to_string()),
1720 paper: Some("https://arxiv.org/abs/1706.03762".to_string()),
1721 },
1722 Method {
1723 name: format!("ResNet-{}", query),
1724 full_name: Some("Residual Network".to_string()),
1725 description: Some("Deep residual learning framework".to_string()),
1726 paper: Some("https://arxiv.org/abs/1512.03385".to_string()),
1727 },
1728 ]
1729 }
1730
1731 async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
1732 let mut retries = 0;
1733 loop {
1734 match self.client.get(url).send().await {
1735 Ok(response) => {
1736 if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES {
1737 retries += 1;
1738 tracing::warn!("Rate limited, retrying in {}ms", RETRY_DELAY_MS * retries as u64);
1739 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1740 continue;
1741 }
1742 if !response.status().is_success() {
1743 return Err(FrameworkError::Network(
1744 reqwest::Error::from(response.error_for_status().unwrap_err()),
1745 ));
1746 }
1747 return Ok(response);
1748 }
1749 Err(_) if retries < MAX_RETRIES => {
1750 retries += 1;
1751 tracing::warn!("Request failed, retrying ({}/{})", retries, MAX_RETRIES);
1752 sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
1753 }
1754 Err(e) => return Err(FrameworkError::Network(e)),
1755 }
1756 }
1757 }
1758}
1759
1760impl Default for PapersWithCodeClient {
1761 fn default() -> Self {
1762 Self::new()
1763 }
1764}
1765
#[cfg(test)]
mod tests {
    // Unit tests for the API client wrappers. Network-free: they exercise
    // constructors, mock data, and local vector conversion only; the one
    // live-network test is #[ignore]d.
    use super::*;

    // ---- HuggingFace ----

    #[test]
    fn test_huggingface_client_creation() {
        let client = HuggingFaceClient::new();
        assert_eq!(client.base_url, "https://huggingface.co/api");
    }

    #[test]
    fn test_huggingface_mock_models() {
        let client = HuggingFaceClient::new();
        let models = client.mock_models("test");
        assert!(!models.is_empty());
        // Mock model ids embed the query string.
        assert!(models[0].model_id.contains("test"));
    }

    #[test]
    fn test_huggingface_model_to_vector() {
        let client = HuggingFaceClient::new();
        let model = HuggingFaceModel {
            model_id: "bert-base-uncased".to_string(),
            author: Some("google".to_string()),
            downloads: Some(1_000_000),
            likes: Some(500),
            tags: Some(vec!["nlp".to_string()]),
            pipeline_tag: Some("fill-mask".to_string()),
            created_at: Some(Utc::now().to_rfc3339()),
        };

        let vector = client.model_to_vector(&model);
        assert_eq!(vector.id, "hf:model:bert-base-uncased");
        assert_eq!(vector.domain, Domain::Research);
        assert!(vector.metadata.contains_key("model_id"));
        assert_eq!(vector.metadata.get("author").unwrap(), "google");
    }

    #[tokio::test]
    async fn test_huggingface_search_models_mock() {
        let client = HuggingFaceClient::new();
        let models = client.search_models("bert", None).await;
        assert!(models.is_ok());
        assert!(!models.unwrap().is_empty());
    }

    // ---- Ollama ----

    #[test]
    fn test_ollama_client_creation() {
        let client = OllamaClient::new();
        assert_eq!(client.base_url, "http://localhost:11434/api");
    }

    #[test]
    fn test_ollama_mock_models() {
        let client = OllamaClient::new();
        let models = client.mock_models();
        assert!(!models.is_empty());
        assert!(models[0].name.contains("llama"));
    }

    #[test]
    fn test_ollama_model_to_vector() {
        let client = OllamaClient::new();
        let model = OllamaModel {
            name: "llama2:latest".to_string(),
            modified_at: Some(Utc::now().to_rfc3339()),
            size: Some(3_800_000_000),
            digest: Some("sha256:abc123".to_string()),
        };

        let vector = client.model_to_vector(&model);
        assert_eq!(vector.id, "ollama:model:llama2:latest");
        assert_eq!(vector.domain, Domain::Research);
        assert!(vector.metadata.contains_key("model_name"));
    }

    #[tokio::test]
    async fn test_ollama_list_models_mock() {
        // Force mock mode so no local Ollama server is required.
        let mut client = OllamaClient::new();
        client.use_mock = true;
        let models = client.list_models().await;
        assert!(models.is_ok());
        assert!(!models.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_ollama_embeddings_mock() {
        // Mock embeddings come from the local SimpleEmbedder, so the
        // dimensionality must match the default.
        let mut client = OllamaClient::new();
        client.use_mock = true;
        let embedding = client.embeddings("llama2", "test text").await;
        assert!(embedding.is_ok());
        assert_eq!(embedding.unwrap().len(), DEFAULT_EMBEDDING_DIM);
    }

    // ---- Replicate ----

    #[test]
    fn test_replicate_client_creation() {
        let client = ReplicateClient::new();
        assert_eq!(client.base_url, "https://api.replicate.com/v1");
    }

    #[test]
    fn test_replicate_mock_model() {
        let client = ReplicateClient::new();
        let model = client.mock_model("owner", "model");
        assert_eq!(model.owner, "owner");
        assert_eq!(model.name, "model");
    }

    #[test]
    fn test_replicate_model_to_vector() {
        let client = ReplicateClient::new();
        let model = ReplicateModel {
            owner: "stability-ai".to_string(),
            name: "stable-diffusion".to_string(),
            description: Some("Text to image model".to_string()),
            visibility: Some("public".to_string()),
            github_url: None,
            paper_url: None,
            latest_version: Some(ReplicateVersion {
                id: "v1.0".to_string(),
                created_at: Some(Utc::now().to_rfc3339()),
            }),
        };

        let vector = client.model_to_vector(&model);
        assert_eq!(vector.id, "replicate:stability-ai/stable-diffusion");
        assert_eq!(vector.domain, Domain::Research);
    }

    #[tokio::test]
    async fn test_replicate_get_model_mock() {
        let client = ReplicateClient::new();
        let model = client.get_model("owner", "model").await;
        assert!(model.is_ok());
        assert!(model.unwrap().is_some());
    }

    // ---- Together AI ----

    #[test]
    fn test_together_client_creation() {
        let client = TogetherAiClient::new();
        assert_eq!(client.base_url, "https://api.together.xyz/v1");
    }

    #[test]
    fn test_together_mock_models() {
        let client = TogetherAiClient::new();
        let models = client.mock_models();
        assert!(!models.is_empty());
        assert!(models[0].id.contains("llama"));
    }

    #[test]
    fn test_together_model_to_vector() {
        let client = TogetherAiClient::new();
        let model = TogetherModel {
            id: "togethercomputer/llama-2-7b".to_string(),
            name: Some("Llama 2 7B".to_string()),
            display_name: Some("Llama 2 7B".to_string()),
            description: Some("Meta's Llama 2 model".to_string()),
            context_length: Some(4096),
            pricing: None,
        };

        let vector = client.model_to_vector(&model);
        assert_eq!(vector.id, "together:togethercomputer/llama-2-7b");
        assert_eq!(vector.domain, Domain::Research);
    }

    #[tokio::test]
    async fn test_together_list_models_mock() {
        // With no TOGETHER_API_KEY set, the client defaults to mock mode.
        let client = TogetherAiClient::new();
        let models = client.list_models().await;
        assert!(models.is_ok());
        assert!(!models.unwrap().is_empty());
    }

    // ---- Papers With Code ----

    #[test]
    fn test_paperswithcode_client_creation() {
        let client = PapersWithCodeClient::new();
        assert_eq!(client.base_url, "https://paperswithcode.com/api/v1");
    }

    #[test]
    fn test_paperswithcode_paper_to_vector() {
        let client = PapersWithCodeClient::new();
        let paper = PaperWithCodePaper {
            id: "attention-is-all-you-need".to_string(),
            title: "Attention Is All You Need".to_string(),
            abstract_text: Some("We propose the Transformer...".to_string()),
            url_abs: Some("https://arxiv.org/abs/1706.03762".to_string()),
            url_pdf: Some("https://arxiv.org/pdf/1706.03762.pdf".to_string()),
            published: Some(Utc::now().to_rfc3339()),
            authors: Some(vec![
                PaperAuthor {
                    name: "Vaswani et al.".to_string(),
                },
            ]),
        };

        let vector = client.paper_to_vector(&paper);
        assert_eq!(vector.id, "pwc:paper:attention-is-all-you-need");
        assert_eq!(vector.domain, Domain::Research);
        assert!(vector.metadata.contains_key("title"));
    }

    #[test]
    fn test_paperswithcode_dataset_to_vector() {
        let client = PapersWithCodeClient::new();
        let dataset = PaperWithCodeDataset {
            id: "imagenet".to_string(),
            name: "ImageNet".to_string(),
            full_name: Some("ImageNet Large Scale Visual Recognition Challenge".to_string()),
            description: Some("Large-scale image dataset".to_string()),
            url: Some("https://image-net.org".to_string()),
            paper: None,
        };

        let vector = client.dataset_to_vector(&dataset);
        assert_eq!(vector.id, "pwc:dataset:imagenet");
        assert_eq!(vector.domain, Domain::Research);
    }

    // Hits the live Papers With Code API; run explicitly with
    // `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore] async fn test_paperswithcode_search_papers_integration() {
        let client = PapersWithCodeClient::new();
        let papers = client.search_papers("transformer").await;
        assert!(papers.is_ok());
    }

    // ---- Cross-client ----

    #[test]
    fn test_all_clients_default() {
        // Each client's Default impl must construct without panicking.
        let _hf = HuggingFaceClient::default();
        let _ollama = OllamaClient::default();
        let _replicate = ReplicateClient::default();
        let _together = TogetherAiClient::default();
        let _pwc = PapersWithCodeClient::default();
    }

    #[test]
    fn test_custom_embedding_dimensions() {
        // Vector length must track the configured embedding dimension.
        let hf = HuggingFaceClient::with_embedding_dim(512);
        let model = hf.mock_models("test")[0].clone();
        let vector = hf.model_to_vector(&model);
        assert_eq!(vector.embedding.len(), 512);

        let pwc = PapersWithCodeClient::with_embedding_dim(768);
        let paper = PaperWithCodePaper {
            id: "test".to_string(),
            title: "Test Paper".to_string(),
            abstract_text: None,
            url_abs: None,
            url_pdf: None,
            published: None,
            authors: None,
        };
        let vector = pwc.paper_to_vector(&paper);
        assert_eq!(vector.embedding.len(), 768);
    }
}