// ruvector_core/embeddings.rs
use crate::error::{Result, RuvectorError};
use std::sync::Arc;

/// Abstraction over text-embedding backends.
///
/// An implementor converts a piece of text into a fixed-length `f32`
/// vector. The bound `Send + Sync` allows a provider to be shared across
/// threads (e.g. behind an `Arc`).
pub trait EmbeddingProvider: Send + Sync {
    /// Embed `text` into a vector whose length is [`Self::dimensions`].
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Length of the vectors produced by [`embed`](Self::embed).
    fn dimensions(&self) -> usize;

    /// Human-readable provider name, intended for logging/diagnostics.
    fn name(&self) -> &str;
}

/// Deterministic, dependency-free placeholder embedding.
///
/// Buckets the input's bytes into `dimensions` slots and L2-normalizes
/// the result. Suitable for tests and wiring checks only — it is not a
/// semantic embedding (its `name()` says "placeholder").
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    // Output vector length. NOTE(review): a value of 0 is degenerate —
    // confirm callers never construct with 0.
    dimensions: usize,
}

impl HashEmbedding {
    /// Create a provider that produces vectors of length `dimensions`.
    pub fn new(dimensions: usize) -> Self {
        Self { dimensions }
    }
}

64impl EmbeddingProvider for HashEmbedding {
65 fn embed(&self, text: &str) -> Result<Vec<f32>> {
66 let mut embedding = vec![0.0; self.dimensions];
67 let bytes = text.as_bytes();
68
69 for (i, byte) in bytes.iter().enumerate() {
70 embedding[i % self.dimensions] += (*byte as f32) / 255.0;
71 }
72
73 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
75 if norm > 0.0 {
76 for val in &mut embedding {
77 *val /= norm;
78 }
79 }
80
81 Ok(embedding)
82 }
83
84 fn dimensions(&self) -> usize {
85 self.dimensions
86 }
87
88 fn name(&self) -> &str {
89 "HashEmbedding (placeholder)"
90 }
91}
92
#[cfg(feature = "real-embeddings")]
pub mod candle {
    use super::*;

    /// Placeholder for a local (Candle-based) embedding model.
    ///
    /// Construction always fails today; use `ApiEmbedding` instead.
    pub struct CandleEmbedding {
        dimensions: usize,
        model_id: String,
    }

    impl CandleEmbedding {
        /// Attempt to load `model_id`. Currently always returns
        /// `ModelLoadError` because local inference is unimplemented.
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            let message = format!(
                "Candle embedding support is a stub. Please:\n\
                 1. Use ApiEmbedding for production (recommended)\n\
                 2. Or implement CandleEmbedding for model: {}\n\
                 3. See docs for ONNX Runtime integration examples",
                model_id
            );
            Err(RuvectorError::ModelLoadError(message))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        /// Always fails — see [`CandleEmbedding`].
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            let reason = "Candle embedding not implemented - use ApiEmbedding instead";
            Err(RuvectorError::ModelInferenceError(reason.to_string()))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}

#[cfg(feature = "real-embeddings")]
pub use candle::CandleEmbedding;

/// Embedding provider backed by a remote HTTP embeddings API.
///
/// Works with OpenAI-compatible endpoints (Voyage uses the same shape)
/// and Cohere's embed endpoint.
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
    api_key: String,    // sent as a Bearer token on every request
    endpoint: String,   // full URL of the embeddings endpoint
    model: String,      // model identifier forwarded in the request body
    dimensions: usize,  // expected output length — not verified against responses
    client: reqwest::blocking::Client, // blocking client, reused across requests
}

#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
    /// Build a provider from explicit connection parameters.
    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
        let client = reqwest::blocking::Client::new();
        Self { api_key, endpoint, model, dimensions, client }
    }

    /// Convenience constructor for OpenAI's embeddings endpoint.
    pub fn openai(api_key: &str, model: &str) -> Self {
        // text-embedding-3-large is 3072-dimensional; all other models
        // default to 1536.
        let dimensions = if model == "text-embedding-3-large" { 3072 } else { 1536 };

        Self::new(
            api_key.to_string(),
            "https://api.openai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }

    /// Convenience constructor for Cohere's embed endpoint (1024-d).
    pub fn cohere(api_key: &str, model: &str) -> Self {
        Self::new(
            api_key.to_string(),
            "https://api.cohere.ai/v1/embed".to_string(),
            model.to_string(),
            1024,
        )
    }

    /// Convenience constructor for Voyage AI: "large" models are 1536-d,
    /// the rest 1024-d.
    pub fn voyage(api_key: &str, model: &str) -> Self {
        let dimensions = match model.contains("large") {
            true => 1536,
            false => 1024,
        };

        Self::new(
            api_key.to_string(),
            "https://api.voyageai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }
}

#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
    /// POST `text` to the configured endpoint and extract one embedding.
    ///
    /// Handles two response shapes:
    /// - OpenAI-style: `{"data": [{"embedding": [...]}]}`
    /// - Cohere-style: `{"embeddings": [[...]]}`
    ///
    /// # Errors
    /// `ModelInferenceError` on transport failure, a non-2xx status, an
    /// unparseable body, an unrecognized response shape, or a non-numeric
    /// embedding entry.
    ///
    /// NOTE(review): this is a *blocking* HTTP call — do not invoke from an
    /// async context without `spawn_blocking` (or equivalent).
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Minimal body accepted by all three supported vendors.
        let request_body = serde_json::json!({
            "input": text,
            "model": self.model,
        });

        let response = self.client
            .post(&self.endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .map_err(|e| RuvectorError::ModelInferenceError(format!("API request failed: {}", e)))?;

        // Surface the API's own error text when the call is rejected.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RuvectorError::ModelInferenceError(
                format!("API returned error {}: {}", status, error_text)
            ));
        }

        let response_json: serde_json::Value = response.json()
            .map_err(|e| RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e)))?;

        // Probe for the OpenAI shape first (`data[0].embedding`), then the
        // Cohere shape (`embeddings[0]`). Only the first item is used —
        // we always send a single input string.
        let embedding = if let Some(data) = response_json.get("data") {
            data.as_array()
                .and_then(|arr| arr.first())
                .and_then(|obj| obj.get("embedding"))
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| RuvectorError::ModelInferenceError(
                    "Invalid OpenAI response format".to_string()
                ))?
        } else if let Some(embeddings) = response_json.get("embeddings") {
            embeddings.as_array()
                .and_then(|arr| arr.first())
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| RuvectorError::ModelInferenceError(
                    "Invalid Cohere response format".to_string()
                ))?
        } else {
            return Err(RuvectorError::ModelInferenceError(
                "Unknown API response format".to_string()
            ));
        };

        // Convert JSON numbers to f32, failing on any non-numeric entry.
        let embedding_vec: Result<Vec<f32>> = embedding
            .iter()
            .map(|v| v.as_f64()
                .map(|f| f as f32)
                .ok_or_else(|| RuvectorError::ModelInferenceError(
                    "Invalid embedding value".to_string()
                ))
            )
            .collect();

        embedding_vec
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn name(&self) -> &str {
        "ApiEmbedding"
    }
}

/// Shared, thread-safe handle to any [`EmbeddingProvider`] implementation.
pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;

#[cfg(test)]
mod tests {
    use super::*;

    /// Same input must give the same, unit-length vector.
    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);

        let first = provider.embed("hello world").unwrap();
        let second = provider.embed("hello world").unwrap();

        assert_eq!(first.len(), 128);
        assert_eq!(first, second, "Same text should produce same embedding");

        let norm: f32 = first.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    /// Distinct inputs should not collide (for these particular strings).
    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(emb1, emb2, "Different text should produce different embeddings");
    }

    // Requires a real model; run explicitly with `cargo test -- --ignored`.
    #[cfg(feature = "real-embeddings")]
    #[test]
    #[ignore]
    fn test_candle_embedding() {
        let provider = CandleEmbedding::from_pretrained(
            "sentence-transformers/all-MiniLM-L6-v2",
            false
        ).unwrap();

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 384);

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    // Fix: this test references `ApiEmbedding`, which only exists under the
    // "api-embeddings" feature — without this cfg gate the whole test build
    // fails when the feature is disabled. Needs OPENAI_API_KEY + network;
    // run explicitly with `cargo test -- --ignored`.
    #[cfg(feature = "api-embeddings")]
    #[test]
    #[ignore]
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap();
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }
}