// ruvector_core/embeddings.rs
use crate::error::Result;
#[cfg(any(feature = "real-embeddings", feature = "api-embeddings"))]
use crate::error::RuvectorError;
use std::sync::Arc;

/// A pluggable text-embedding backend.
///
/// Implementations must be `Send + Sync` so a single provider can be shared
/// across threads behind an `Arc` (see `BoxedEmbeddingProvider`).
pub trait EmbeddingProvider: Send + Sync {
    /// Embed `text` into a vector of `dimensions()` floats.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Length of the vectors produced by `embed`.
    fn dimensions(&self) -> usize;

    /// Human-readable provider name, intended for logging/diagnostics.
    fn name(&self) -> &str;
}

/// Deterministic, dependency-free placeholder embedding based on byte
/// bucketing. Not semantically meaningful — intended for tests and local
/// development only (see `name()`).
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    // Length of the vectors produced by `embed`.
    dimensions: usize,
}

59impl HashEmbedding {
60 pub fn new(dimensions: usize) -> Self {
62 Self { dimensions }
63 }
64}
65
66impl EmbeddingProvider for HashEmbedding {
67 fn embed(&self, text: &str) -> Result<Vec<f32>> {
68 let mut embedding = vec![0.0; self.dimensions];
69 let bytes = text.as_bytes();
70
71 for (i, byte) in bytes.iter().enumerate() {
72 embedding[i % self.dimensions] += (*byte as f32) / 255.0;
73 }
74
75 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
77 if norm > 0.0 {
78 for val in &mut embedding {
79 *val /= norm;
80 }
81 }
82
83 Ok(embedding)
84 }
85
86 fn dimensions(&self) -> usize {
87 self.dimensions
88 }
89
90 fn name(&self) -> &str {
91 "HashEmbedding (placeholder)"
92 }
93}
94
#[cfg(feature = "real-embeddings")]
pub mod candle {
    use super::*;

    /// Stub local-model provider. Construction always fails; kept so the
    /// `real-embeddings` feature compiles while pointing users at
    /// `ApiEmbedding`.
    pub struct CandleEmbedding {
        dimensions: usize,
        model_id: String,
    }

    impl CandleEmbedding {
        /// Always returns `ModelLoadError` — Candle support is unimplemented.
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            let message = format!(
                "Candle embedding support is a stub. Please:\n\
                 1. Use ApiEmbedding for production (recommended)\n\
                 2. Or implement CandleEmbedding for model: {}\n\
                 3. See docs for ONNX Runtime integration examples",
                model_id
            );
            Err(RuvectorError::ModelLoadError(message))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        /// Always fails: no inference backend exists for this stub.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            let reason = "Candle embedding not implemented - use ApiEmbedding instead";
            Err(RuvectorError::ModelInferenceError(reason.to_string()))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}

#[cfg(feature = "real-embeddings")]
pub use candle::CandleEmbedding;

/// Embedding provider backed by a remote HTTP embeddings API.
/// See the `openai` / `cohere` / `voyage` constructors for presets.
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Full URL of the embeddings endpoint.
    endpoint: String,
    // Model identifier forwarded in the request body.
    model: String,
    // Expected output vector length (not validated against responses).
    dimensions: usize,
    // Blocking client: `embed` performs a synchronous HTTP round-trip.
    client: reqwest::blocking::Client,
}

#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
    /// Build a provider from explicit connection parameters.
    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
        let client = reqwest::blocking::Client::new();
        Self {
            api_key,
            endpoint,
            model,
            dimensions,
            client,
        }
    }

    /// Preset for the OpenAI embeddings endpoint.
    pub fn openai(api_key: &str, model: &str) -> Self {
        // text-embedding-3-large is 3072-dimensional; every other model
        // name falls back to 1536.
        let dimensions = if model == "text-embedding-3-large" {
            3072
        } else {
            1536
        };

        Self::new(
            api_key.to_owned(),
            String::from("https://api.openai.com/v1/embeddings"),
            model.to_owned(),
            dimensions,
        )
    }

    /// Preset for the Cohere embed endpoint (fixed 1024 dimensions).
    pub fn cohere(api_key: &str, model: &str) -> Self {
        Self::new(
            api_key.to_owned(),
            String::from("https://api.cohere.ai/v1/embed"),
            model.to_owned(),
            1024,
        )
    }

    /// Preset for the Voyage AI embeddings endpoint.
    pub fn voyage(api_key: &str, model: &str) -> Self {
        // "large" model variants are 1536-dimensional, the rest 1024.
        let dimensions = match model.contains("large") {
            true => 1536,
            false => 1024,
        };

        Self::new(
            api_key.to_owned(),
            String::from("https://api.voyageai.com/v1/embeddings"),
            model.to_owned(),
            dimensions,
        )
    }
}

#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
    /// Embed `text` with one blocking HTTP round-trip to `self.endpoint`.
    ///
    /// Understands two response shapes: OpenAI-style (`data[0].embedding`)
    /// and Cohere-style (`embeddings[0]`). Returns `ModelInferenceError` on
    /// transport failure, a non-2xx status, or an unrecognized body.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // OpenAI/Voyage request shape. NOTE(review): Cohere's /v1/embed
        // expects a `texts` array rather than `input` — confirm this body
        // actually works with the `cohere()` constructor's endpoint.
        let request_body = serde_json::json!({
            "input": text,
            "model": self.model,
        });

        let response = self
            .client
            .post(&self.endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .map_err(|e| {
                RuvectorError::ModelInferenceError(format!("API request failed: {}", e))
            })?;

        // Surface HTTP-level failures with the server's own error text.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RuvectorError::ModelInferenceError(format!(
                "API returned error {}: {}",
                status, error_text
            )));
        }

        let response_json: serde_json::Value = response.json().map_err(|e| {
            RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e))
        })?;

        // Locate the raw JSON number array, trying the OpenAI layout first
        // and falling back to the Cohere layout. Only the FIRST embedding in
        // the response is used.
        let embedding = if let Some(data) = response_json.get("data") {
            // OpenAI / Voyage: { "data": [ { "embedding": [...] } ] }
            data.as_array()
                .and_then(|arr| arr.first())
                .and_then(|obj| obj.get("embedding"))
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid OpenAI response format".to_string())
                })?
        } else if let Some(embeddings) = response_json.get("embeddings") {
            // Cohere: { "embeddings": [ [...] ] }
            embeddings
                .as_array()
                .and_then(|arr| arr.first())
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid Cohere response format".to_string())
                })?
        } else {
            return Err(RuvectorError::ModelInferenceError(
                "Unknown API response format".to_string(),
            ));
        };

        // Convert JSON numbers to f32; collecting into Result short-circuits
        // at the first non-numeric entry.
        let embedding_vec: Result<Vec<f32>> = embedding
            .iter()
            .map(|v| {
                v.as_f64().map(|f| f as f32).ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid embedding value".to_string())
                })
            })
            .collect();

        embedding_vec
    }

    // Configured (expected) dimensionality; not checked against responses.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    // Static label; does not include the endpoint or model name.
    fn name(&self) -> &str {
        "ApiEmbedding"
    }
}

/// Shared, thread-safe handle to any embedding provider.
pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;

#[cfg(test)]
mod tests {
    use super::*;

    /// Hash embeddings must be deterministic, correctly sized, and
    /// L2-normalized.
    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1.len(), 128);
        assert_eq!(emb1, emb2, "Same text should produce same embedding");

        let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    /// Distinct inputs should not collide (for these particular strings).
    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(
            emb1, emb2,
            "Different text should produce different embeddings"
        );
    }

    // Requires a model download; run explicitly via `cargo test -- --ignored`.
    #[cfg(feature = "real-embeddings")]
    #[test]
    #[ignore]
    fn test_candle_embedding() {
        let provider =
            CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false)
                .unwrap();

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 384);

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    // BUG FIX: this test was missing the feature gate, so the whole test
    // module failed to compile when `api-embeddings` was disabled
    // (`ApiEmbedding` does not exist without it). Needs OPENAI_API_KEY;
    // run with `cargo test -- --ignored`.
    #[cfg(feature = "api-embeddings")]
    #[test]
    #[ignore]
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap();
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }
}