// ruvector_core/embeddings.rs
use crate::error::{Result, RuvectorError};
use std::sync::Arc;
29
/// Abstraction over text-to-vector embedding backends.
///
/// Implementations must be `Send + Sync` so a single provider can be
/// shared across threads behind an `Arc` (see `BoxedEmbeddingProvider`).
pub trait EmbeddingProvider: Send + Sync {
    /// Embed `text` into a fixed-length `f32` vector.
    ///
    /// The returned vector's length should equal `dimensions()`.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Number of components in vectors produced by `embed`.
    fn dimensions(&self) -> usize;

    /// Human-readable provider name, for logging/diagnostics.
    fn name(&self) -> &str;
}
41
/// Deterministic, dependency-free placeholder embedding.
///
/// Folds raw UTF-8 bytes into a fixed number of buckets and L2-normalizes
/// the result. The output is NOT semantically meaningful — intended for
/// tests and plumbing only (its `name()` says "placeholder").
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    // Length of every vector returned by `embed`.
    dimensions: usize,
}
56
57impl HashEmbedding {
58 pub fn new(dimensions: usize) -> Self {
60 Self { dimensions }
61 }
62}
63
64impl EmbeddingProvider for HashEmbedding {
65 fn embed(&self, text: &str) -> Result<Vec<f32>> {
66 let mut embedding = vec![0.0; self.dimensions];
67 let bytes = text.as_bytes();
68
69 for (i, byte) in bytes.iter().enumerate() {
70 embedding[i % self.dimensions] += (*byte as f32) / 255.0;
71 }
72
73 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
75 if norm > 0.0 {
76 for val in &mut embedding {
77 *val /= norm;
78 }
79 }
80
81 Ok(embedding)
82 }
83
84 fn dimensions(&self) -> usize {
85 self.dimensions
86 }
87
88 fn name(&self) -> &str {
89 "HashEmbedding (placeholder)"
90 }
91}
92
#[cfg(feature = "real-embeddings")]
pub mod candle {
    //! Stub for a local-model (Candle) embedding backend.
    //!
    //! Compiled only with the `real-embeddings` feature. Every entry point
    //! currently returns an error steering callers toward `ApiEmbedding`;
    //! no model is ever loaded.
    use super::*;

    /// Placeholder for a locally-run embedding model (not implemented).
    pub struct CandleEmbedding {
        // Expected output vector width once a model is actually loaded.
        // NOTE(review): never written — `from_pretrained` always errors,
        // so no `CandleEmbedding` value can currently be constructed.
        dimensions: usize,
        // Model identifier, e.g. "sentence-transformers/all-MiniLM-L6-v2".
        model_id: String,
    }

    impl CandleEmbedding {
        /// Intended to load a pretrained model; currently always returns
        /// `ModelLoadError` explaining the available alternatives.
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            Err(RuvectorError::ModelLoadError(
                format!(
                    "Candle embedding support is a stub. Please:\n\
                     1. Use ApiEmbedding for production (recommended)\n\
                     2. Or implement CandleEmbedding for model: {}\n\
                     3. See docs for ONNX Runtime integration examples",
                    model_id
                )
            ))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        /// Always fails: inference is not implemented.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Err(RuvectorError::ModelInferenceError(
                "Candle embedding not implemented - use ApiEmbedding instead".to_string()
            ))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}
179
180#[cfg(feature = "real-embeddings")]
181pub use candle::CandleEmbedding;
182
/// Embedding provider backed by a remote HTTP API (OpenAI, Cohere, Voyage).
///
/// NOTE(review): `Debug` is not derived — possibly intentional, since it
/// keeps `api_key` out of debug output; confirm before adding it.
#[derive(Clone)]
pub struct ApiEmbedding {
    // Bearer token sent in the Authorization header of every request.
    api_key: String,
    // Full URL of the embeddings endpoint.
    endpoint: String,
    // Model identifier forwarded in the request body.
    model: String,
    // Expected vector width (not validated against actual API responses).
    dimensions: usize,
    // Blocking HTTP client, reused across requests.
    client: reqwest::blocking::Client,
}
203
204impl ApiEmbedding {
205 pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
213 Self {
214 api_key,
215 endpoint,
216 model,
217 dimensions,
218 client: reqwest::blocking::Client::new(),
219 }
220 }
221
222 pub fn openai(api_key: &str, model: &str) -> Self {
229 let dimensions = match model {
230 "text-embedding-3-large" => 3072,
231 _ => 1536, };
233
234 Self::new(
235 api_key.to_string(),
236 "https://api.openai.com/v1/embeddings".to_string(),
237 model.to_string(),
238 dimensions,
239 )
240 }
241
242 pub fn cohere(api_key: &str, model: &str) -> Self {
248 Self::new(
249 api_key.to_string(),
250 "https://api.cohere.ai/v1/embed".to_string(),
251 model.to_string(),
252 1024,
253 )
254 }
255
256 pub fn voyage(api_key: &str, model: &str) -> Self {
262 let dimensions = if model.contains("large") { 1536 } else { 1024 };
263
264 Self::new(
265 api_key.to_string(),
266 "https://api.voyageai.com/v1/embeddings".to_string(),
267 model.to_string(),
268 dimensions,
269 )
270 }
271}
272
273impl EmbeddingProvider for ApiEmbedding {
274 fn embed(&self, text: &str) -> Result<Vec<f32>> {
275 let request_body = serde_json::json!({
276 "input": text,
277 "model": self.model,
278 });
279
280 let response = self.client
281 .post(&self.endpoint)
282 .header("Authorization", format!("Bearer {}", self.api_key))
283 .header("Content-Type", "application/json")
284 .json(&request_body)
285 .send()
286 .map_err(|e| RuvectorError::ModelInferenceError(format!("API request failed: {}", e)))?;
287
288 if !response.status().is_success() {
289 let status = response.status();
290 let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
291 return Err(RuvectorError::ModelInferenceError(
292 format!("API returned error {}: {}", status, error_text)
293 ));
294 }
295
296 let response_json: serde_json::Value = response.json()
297 .map_err(|e| RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e)))?;
298
299 let embedding = if let Some(data) = response_json.get("data") {
301 data.as_array()
303 .and_then(|arr| arr.first())
304 .and_then(|obj| obj.get("embedding"))
305 .and_then(|emb| emb.as_array())
306 .ok_or_else(|| RuvectorError::ModelInferenceError(
307 "Invalid OpenAI response format".to_string()
308 ))?
309 } else if let Some(embeddings) = response_json.get("embeddings") {
310 embeddings.as_array()
312 .and_then(|arr| arr.first())
313 .and_then(|emb| emb.as_array())
314 .ok_or_else(|| RuvectorError::ModelInferenceError(
315 "Invalid Cohere response format".to_string()
316 ))?
317 } else {
318 return Err(RuvectorError::ModelInferenceError(
319 "Unknown API response format".to_string()
320 ));
321 };
322
323 let embedding_vec: Result<Vec<f32>> = embedding
324 .iter()
325 .map(|v| v.as_f64()
326 .map(|f| f as f32)
327 .ok_or_else(|| RuvectorError::ModelInferenceError(
328 "Invalid embedding value".to_string()
329 ))
330 )
331 .collect();
332
333 embedding_vec
334 }
335
336 fn dimensions(&self) -> usize {
337 self.dimensions
338 }
339
340 fn name(&self) -> &str {
341 "ApiEmbedding"
342 }
343}
344
345pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
347
#[cfg(test)]
mod tests {
    use super::*;

    /// HashEmbedding is deterministic, correctly sized, and unit-length.
    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);

        let first = provider.embed("hello world").unwrap();
        let second = provider.embed("hello world").unwrap();

        assert_eq!(first.len(), 128);
        assert_eq!(first, second, "Same text should produce same embedding");

        let norm = first.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    /// Distinct inputs should not collide into identical vectors.
    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);

        let first = provider.embed("hello").unwrap();
        let second = provider.embed("world").unwrap();

        assert_ne!(first, second, "Different text should produce different embeddings");
    }

    /// Requires the stub backend to be implemented; ignored until then.
    #[cfg(feature = "real-embeddings")]
    #[test]
    #[ignore]
    fn test_candle_embedding() {
        let provider = CandleEmbedding::from_pretrained(
            "sentence-transformers/all-MiniLM-L6-v2",
            false,
        )
        .unwrap();

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 384);

        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    /// Live-network test; needs OPENAI_API_KEY, so ignored by default.
    #[test]
    #[ignore]
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap();
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }
}