Skip to main content

synwire_core/embeddings/
fake.rs

1//! Deterministic fake embeddings for testing.
2
3use crate::BoxFuture;
4use crate::error::SynwireError;
5
6use super::traits::Embeddings;
7
/// Deterministic fake embeddings for testing.
///
/// Generates vectors by hashing the input text. The same text always produces
/// the same embedding vector, making tests reproducible.
///
/// # Examples
///
/// ```
/// use synwire_core::embeddings::FakeEmbeddings;
/// let embeddings = FakeEmbeddings::new(128);
/// ```
#[derive(Debug, Clone)]
pub struct FakeEmbeddings {
    /// Number of components in every generated vector.
    dimensions: usize,
}
22
23impl FakeEmbeddings {
24    /// Creates a new `FakeEmbeddings` with the given vector dimensionality.
25    pub const fn new(dimensions: usize) -> Self {
26        Self { dimensions }
27    }
28
29    /// Generates a deterministic embedding vector for a given text.
30    fn embed_text(&self, text: &str) -> Vec<f32> {
31        let mut vector = Vec::with_capacity(self.dimensions);
32        for i in 0..self.dimensions {
33            #[allow(clippy::cast_possible_truncation)]
34            let hash = text.bytes().enumerate().fold(0u32, |acc, (j, b)| {
35                acc.wrapping_add(
36                    u32::from(b)
37                        .wrapping_mul((j + 1) as u32)
38                        .wrapping_mul((i + 1) as u32),
39                )
40            });
41            // Map to [0, 1) range
42            #[allow(clippy::cast_precision_loss)]
43            let val = (hash % 10_000) as f32 / 10_000.0;
44            vector.push(val);
45        }
46
47        // Normalize to unit length
48        #[allow(clippy::cast_precision_loss)]
49        let magnitude = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
50        if magnitude > f32::EPSILON {
51            for v in &mut vector {
52                *v /= magnitude;
53            }
54        }
55
56        vector
57    }
58}
59
60impl Embeddings for FakeEmbeddings {
61    fn embed_documents<'a>(
62        &'a self,
63        texts: &'a [String],
64    ) -> BoxFuture<'a, Result<Vec<Vec<f32>>, SynwireError>> {
65        Box::pin(async move { Ok(texts.iter().map(|t| self.embed_text(t)).collect()) })
66    }
67
68    fn embed_query<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Vec<f32>, SynwireError>> {
69        Box::pin(async move { Ok(self.embed_text(text)) })
70    }
71}
72
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A batch of documents yields one vector per document, each of the
    /// configured length.
    #[tokio::test]
    async fn fake_embeddings_returns_consistent_dimensions() {
        let inputs: Vec<String> = ["hello", "world", "foo bar"]
            .iter()
            .map(|s| (*s).to_owned())
            .collect();
        let embedder = FakeEmbeddings::new(64);

        let vectors = embedder.embed_documents(&inputs).await.unwrap();

        assert_eq!(vectors.len(), 3);
        for vector in &vectors {
            assert_eq!(vector.len(), 64);
        }
    }

    /// A query yields a single vector of the configured length.
    #[tokio::test]
    async fn embed_query_returns_single_vector() {
        let embedder = FakeEmbeddings::new(32);
        let vector = embedder.embed_query("test query").await.unwrap();
        assert_eq!(vector.len(), 32);
    }

    /// Embedding the same text twice gives identical vectors.
    #[tokio::test]
    async fn fake_embeddings_are_deterministic() {
        let embedder = FakeEmbeddings::new(16);
        let first = embedder.embed_query("hello").await.unwrap();
        let second = embedder.embed_query("hello").await.unwrap();
        assert_eq!(first, second);
    }

    /// The vector for a non-empty text has (approximately) unit length.
    #[tokio::test]
    async fn fake_embeddings_are_normalized() {
        let embedder = FakeEmbeddings::new(64);
        let vector = embedder.embed_query("normalize me").await.unwrap();
        let squared_sum: f32 = vector.iter().map(|x| x * x).sum();
        let magnitude = squared_sum.sqrt();
        assert!((magnitude - 1.0).abs() < 1e-4, "magnitude = {magnitude}");
    }

    /// Two distinct sample texts do not map to the same vector.
    #[tokio::test]
    async fn different_texts_produce_different_embeddings() {
        let embedder = FakeEmbeddings::new(16);
        let alpha = embedder.embed_query("alpha").await.unwrap();
        let beta = embedder.embed_query("beta").await.unwrap();
        assert_ne!(alpha, beta);
    }
}