synaptic_huggingface/
lib.rs1use async_trait::async_trait;
2use synaptic_core::{Embeddings, SynapticError};
3
/// Connection settings for the HuggingFace embeddings backend.
///
/// `Debug` is implemented by hand (not derived) so that the API key is never
/// written to logs or panic messages.
#[derive(Clone)]
pub struct HuggingFaceEmbeddingsConfig {
    /// Model identifier; appended to `base_url` to form the request URL.
    pub model: String,
    /// Optional token, sent as `Authorization: Bearer <key>` when present.
    pub api_key: Option<String>,
    /// Endpoint prefix that the model identifier is appended to.
    pub base_url: String,
    /// When `true`, requests carry the `x-wait-for-model: true` header.
    pub wait_for_model: bool,
}

impl std::fmt::Debug for HuggingFaceEmbeddingsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None distinction visible, but never the secret itself.
        let api_key = self.api_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("HuggingFaceEmbeddingsConfig")
            .field("model", &self.model)
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .field("wait_for_model", &self.wait_for_model)
            .finish()
    }
}
11
12impl HuggingFaceEmbeddingsConfig {
13 pub fn new(model: impl Into<String>) -> Self {
14 Self {
15 model: model.into(),
16 api_key: None,
17 base_url: "https://api-inference.huggingface.co/models".to_string(),
18 wait_for_model: true,
19 }
20 }
21 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
22 self.api_key = Some(api_key.into());
23 self
24 }
25 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
26 self.base_url = base_url.into();
27 self
28 }
29 pub fn with_wait_for_model(mut self, wait: bool) -> Self {
30 self.wait_for_model = wait;
31 self
32 }
33}
34
35pub struct HuggingFaceEmbeddings {
36 config: HuggingFaceEmbeddingsConfig,
37 client: reqwest::Client,
38}
39
40impl HuggingFaceEmbeddings {
41 pub fn new(config: HuggingFaceEmbeddingsConfig) -> Self {
42 Self {
43 config,
44 client: reqwest::Client::new(),
45 }
46 }
47 pub fn with_client(config: HuggingFaceEmbeddingsConfig, client: reqwest::Client) -> Self {
48 Self { config, client }
49 }
50
51 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
52 if texts.is_empty() {
53 return Ok(Vec::new());
54 }
55 let url = format!("{}/{}", self.config.base_url, self.config.model);
56 let body = serde_json::json!({ "inputs": texts });
57 let mut request = self
58 .client
59 .post(&url)
60 .header("Content-Type", "application/json");
61 if let Some(ref key) = self.config.api_key {
62 request = request.header("Authorization", format!("Bearer {key}"));
63 }
64 if self.config.wait_for_model {
65 request = request.header("x-wait-for-model", "true");
66 }
67 let response = request
68 .json(&body)
69 .send()
70 .await
71 .map_err(|e| SynapticError::Embedding(format!("HuggingFace request: {e}")))?;
72 let status = response.status();
73 if status.is_client_error() || status.is_server_error() {
74 let code = status.as_u16();
75 let text = response.text().await.unwrap_or_default();
76 return Err(SynapticError::Embedding(format!(
77 "HuggingFace API error ({code}): {text}"
78 )));
79 }
80 let resp: serde_json::Value = response
81 .json()
82 .await
83 .map_err(|e| SynapticError::Embedding(format!("HuggingFace parse: {e}")))?;
84 parse_hf_response(&resp)
85 }
86}
87
88fn parse_hf_response(resp: &serde_json::Value) -> Result<Vec<Vec<f32>>, SynapticError> {
89 let array = if let Some(arr) = resp.as_array() {
90 arr
91 } else if let Some(arr) = resp.get("embeddings").and_then(|e| e.as_array()) {
92 arr
93 } else {
94 return Err(SynapticError::Embedding(
95 "unexpected HuggingFace response format".to_string(),
96 ));
97 };
98 let mut result = Vec::with_capacity(array.len());
99 for item in array {
100 let embedding: Vec<f32> = item
101 .as_array()
102 .ok_or_else(|| SynapticError::Embedding("embedding item is not array".to_string()))?
103 .iter()
104 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
105 .collect();
106 result.push(embedding);
107 }
108 Ok(result)
109}
110
111#[async_trait]
112impl Embeddings for HuggingFaceEmbeddings {
113 async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
114 self.embed_batch(texts).await
115 }
116 async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
117 let mut results = self.embed_batch(&[text]).await?;
118 results
119 .pop()
120 .ok_or_else(|| SynapticError::Embedding("empty HuggingFace response".to_string()))
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` must fill in every documented default, not just the model name.
    #[test]
    fn config_defaults() {
        let c = HuggingFaceEmbeddingsConfig::new("BAAI/bge-small-en-v1.5");
        assert_eq!(c.model, "BAAI/bge-small-en-v1.5");
        assert_eq!(c.api_key, None);
        assert_eq!(c.base_url, "https://api-inference.huggingface.co/models");
        assert!(c.wait_for_model);
    }

    /// Every builder method must update its own field and leave the rest alone.
    #[test]
    fn config_builder() {
        let c = HuggingFaceEmbeddingsConfig::new("model")
            .with_api_key("hf_test")
            .with_base_url("http://localhost:8080")
            .with_wait_for_model(false);
        assert_eq!(c.model, "model");
        assert_eq!(c.api_key, Some("hf_test".to_string()));
        assert_eq!(c.base_url, "http://localhost:8080");
        assert!(!c.wait_for_model);
    }

    /// A top-level array of rows is returned value-for-value.
    #[test]
    fn parse_direct_array() {
        let resp = serde_json::json!([[0.1_f32, 0.2_f32]]);
        let result = parse_hf_response(&resp).unwrap();
        assert_eq!(result, vec![vec![0.1_f32, 0.2_f32]]);
    }

    /// Rows wrapped in an `"embeddings"` field are accepted as well.
    #[test]
    fn parse_embeddings_object() {
        let resp = serde_json::json!({ "embeddings": [[1.0, 2.0], [3.0, 4.0]] });
        let result = parse_hf_response(&resp).unwrap();
        assert_eq!(result, vec![vec![1.0_f32, 2.0], vec![3.0, 4.0]]);
    }

    /// Bodies without rows, or whose rows are not arrays, are errors.
    #[test]
    fn parse_rejects_unexpected_shapes() {
        assert!(parse_hf_response(&serde_json::json!({ "error": "loading" })).is_err());
        assert!(parse_hf_response(&serde_json::json!([1.0, 2.0])).is_err());
    }
}