// ruvector_core/embeddings.rs

use crate::error::{Result, RuvectorError};
use std::sync::Arc;
30pub trait EmbeddingProvider: Send + Sync {
32 fn embed(&self, text: &str) -> Result<Vec<f32>>;
34
35 fn dimensions(&self) -> usize;
37
38 fn name(&self) -> &str;
40}
41
/// Deterministic, dependency-free placeholder embedding.
///
/// Buckets the byte values of the input text into a fixed-size vector and
/// L2-normalizes it. The result is NOT semantically meaningful — intended
/// for tests and plumbing, not production similarity search.
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    // Output vector length reported by `dimensions()`.
    dimensions: usize,
}

impl HashEmbedding {
    /// Create a provider that emits vectors of `dimensions` length.
    pub fn new(dimensions: usize) -> Self {
        Self { dimensions }
    }
}
63
64impl EmbeddingProvider for HashEmbedding {
65 fn embed(&self, text: &str) -> Result<Vec<f32>> {
66 let mut embedding = vec![0.0; self.dimensions];
67 let bytes = text.as_bytes();
68
69 for (i, byte) in bytes.iter().enumerate() {
70 embedding[i % self.dimensions] += (*byte as f32) / 255.0;
71 }
72
73 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
75 if norm > 0.0 {
76 for val in &mut embedding {
77 *val /= norm;
78 }
79 }
80
81 Ok(embedding)
82 }
83
84 fn dimensions(&self) -> usize {
85 self.dimensions
86 }
87
88 fn name(&self) -> &str {
89 "HashEmbedding (placeholder)"
90 }
91}
92
#[cfg(feature = "real-embeddings")]
pub mod candle {
    //! Stub for a local (Candle/ONNX) embedding backend.
    //!
    //! Nothing here performs inference: construction and embedding both
    //! return errors directing callers to `ApiEmbedding`.
    use super::*;

    /// Placeholder for an on-device transformer embedding model.
    ///
    /// Cannot currently be constructed: [`CandleEmbedding::from_pretrained`]
    /// always fails. The fields fix the intended shape of the eventual API.
    pub struct CandleEmbedding {
        dimensions: usize,
        // Retained for the eventual real implementation; never read today.
        #[allow(dead_code)]
        model_id: String,
    }

    impl CandleEmbedding {
        /// Always returns `ModelLoadError`: local inference is not wired up.
        /// `_use_gpu` is accepted but ignored by the stub.
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            Err(RuvectorError::ModelLoadError(format!(
                "Candle embedding support is a stub. Please:\n\
                 1. Use ApiEmbedding for production (recommended)\n\
                 2. Or implement CandleEmbedding for model: {}\n\
                 3. See docs for ONNX Runtime integration examples",
                model_id
            )))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        /// Always errors; use `ApiEmbedding` instead.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Err(RuvectorError::ModelInferenceError(
                "Candle embedding not implemented - use ApiEmbedding instead".to_string(),
            ))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}

#[cfg(feature = "real-embeddings")]
pub use candle::CandleEmbedding;
180
/// Remote embedding provider backed by an HTTP embeddings API.
///
/// Understands both OpenAI-style (`data[0].embedding`) and Cohere-style
/// (`embeddings[0]`) response shapes — see the `EmbeddingProvider` impl.
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Full URL of the embeddings endpoint.
    endpoint: String,
    // Model identifier forwarded in the request body.
    model: String,
    // Expected embedding length; reported via `dimensions()`, not
    // enforced against actual API responses.
    dimensions: usize,
    // Reused blocking HTTP client so connections are pooled across calls.
    client: reqwest::blocking::Client,
}
202
#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
    /// Build a provider for an arbitrary endpoint.
    ///
    /// `dimensions` is the length the remote model is expected to return;
    /// it is reported through `EmbeddingProvider::dimensions` but is not
    /// validated against actual responses.
    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
        Self {
            api_key,
            endpoint,
            model,
            dimensions,
            client: reqwest::blocking::Client::new(),
        }
    }

    /// Provider for the OpenAI embeddings endpoint.
    ///
    /// Known sizes: `text-embedding-3-large` => 3072; everything else
    /// (e.g. `text-embedding-3-small`, `text-embedding-ada-002`) => 1536.
    pub fn openai(api_key: &str, model: &str) -> Self {
        let dimensions = match model {
            "text-embedding-3-large" => 3072,
            _ => 1536, // 3-small and ada-002 both emit 1536
        };

        Self::new(
            api_key.to_string(),
            "https://api.openai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }

    /// Provider for the Cohere embed endpoint.
    ///
    /// NOTE(review): dimensions are hard-coded to 1024 (embed-v3.0 size);
    /// "light" models reportedly return 384 — confirm before relying on
    /// `dimensions()` for those models.
    pub fn cohere(api_key: &str, model: &str) -> Self {
        Self::new(
            api_key.to_string(),
            "https://api.cohere.ai/v1/embed".to_string(),
            model.to_string(),
            1024,
        )
    }

    /// Provider for the Voyage AI embeddings endpoint.
    ///
    /// Models whose name contains "large" are assumed to emit 1536-dim
    /// vectors; all others 1024.
    pub fn voyage(api_key: &str, model: &str) -> Self {
        let dimensions = if model.contains("large") { 1536 } else { 1024 };

        Self::new(
            api_key.to_string(),
            "https://api.voyageai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }
}
272
#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
    /// POST the text to the configured endpoint and parse the embedding.
    ///
    /// Sends `{"input": text, "model": model}` with a Bearer token, then
    /// accepts either an OpenAI-shaped response (`data[0].embedding`) or a
    /// Cohere-shaped one (`embeddings[0]`). Blocking call — do not use on
    /// an async runtime thread.
    ///
    /// # Errors
    /// `ModelInferenceError` on transport failure, non-2xx status, an
    /// unrecognized response shape, or non-numeric embedding values.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let request_body = serde_json::json!({
            "input": text,
            "model": self.model,
        });

        let response = self
            .client
            .post(&self.endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .map_err(|e| {
                RuvectorError::ModelInferenceError(format!("API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            // Capture the status before consuming the body for the message.
            let status = response.status();
            let error_text = response
                .text()
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RuvectorError::ModelInferenceError(format!(
                "API returned error {}: {}",
                status, error_text
            )));
        }

        let response_json: serde_json::Value = response.json().map_err(|e| {
            RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e))
        })?;

        // OpenAI shape: { "data": [ { "embedding": [...] } ] }
        let embedding = if let Some(data) = response_json.get("data") {
            data.as_array()
                .and_then(|arr| arr.first())
                .and_then(|obj| obj.get("embedding"))
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid OpenAI response format".to_string())
                })?
        // Cohere shape: { "embeddings": [ [...] ] }
        } else if let Some(embeddings) = response_json.get("embeddings") {
            embeddings
                .as_array()
                .and_then(|arr| arr.first())
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid Cohere response format".to_string())
                })?
        } else {
            return Err(RuvectorError::ModelInferenceError(
                "Unknown API response format".to_string(),
            ));
        };

        // Convert JSON numbers to f32, failing fast on any non-numeric entry.
        let embedding_vec: Result<Vec<f32>> = embedding
            .iter()
            .map(|v| {
                v.as_f64().map(|f| f as f32).ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid embedding value".to_string())
                })
            })
            .collect();

        embedding_vec
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn name(&self) -> &str {
        "ApiEmbedding"
    }
}
352
353pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Same input must map to the same, unit-norm vector.
    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1.len(), 128);
        assert_eq!(emb1, emb2, "Same text should produce same embedding");

        let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(
            emb1, emb2,
            "Different text should produce different embeddings"
        );
    }

    /// Exercises the stubbed Candle backend; ignored until implemented.
    #[cfg(feature = "real-embeddings")]
    #[test]
    #[ignore]
    fn test_candle_embedding() {
        let provider =
            CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false)
                .unwrap();

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 384);

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    /// Hits the live OpenAI API; needs OPENAI_API_KEY, so ignored by default.
    // Fix: gate on the feature that defines ApiEmbedding, otherwise
    // `cargo test` without `api-embeddings` fails to compile this module.
    #[cfg(feature = "api-embeddings")]
    #[test]
    #[ignore]
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap();
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }
}