semantic_search/
embedding.rs

1//! # Embedding module
2//!
3//! Embedding types, representation, conversion and calculation. Assumes little-endian byte order.
4//!
5//! ## Types
6//!
7//! - [`EmbeddingRaw`]: Raw embedding representation, alias for `[f32; 1024]`.
8//! - [`EmbeddingBytes`]: Embedding represented in bytes (little-endian), alias for `[u8; 1024 * 4]`.
9//! - [`Embedding`]: Wrapped embedding representation.
10//!
11//! ## Representation
12//!
//! An embedding is represented as a 1024-dimensional vector of 32-bit floating point numbers. The [`Embedding`] struct is a wrapper around [`EmbeddingRaw`] (alias for `[f32; 1024]`), and provides methods for conversion and calculation.
14//!
15//! ## Conversion
16//!
17//! - [`Embedding`] can be converted from [`EmbeddingRaw`] and [`EmbeddingBytes`].
18//! - [`Embedding`] can be immutably dereferenced to [`EmbeddingRaw`] and converted to [`EmbeddingBytes`].
//! - [`Embedding`] can be converted from `&[f32]`, `&[u8]`, `Vec<f32>` and `Vec<u8>`; a [`DimensionMismatch`](SenseError::DimensionMismatch) error is returned if the length does not match.
20//!
21//! ## Calculation
22//!
//! Cosine similarity between two embeddings can be calculated using the [`cosine_similarity`](Embedding::cosine_similarity) method.
24
25use super::SenseError;
26use std::{convert::TryFrom, ops::Deref};
27
/// Raw embedding representation: a fixed-size array of 1024 `f32` components.
pub type EmbeddingRaw = [f32; 1024];

/// Embedding represented in bytes (little-endian).
///
/// The length is one `f32`-sized group of bytes per component, i.e. `1024 * 4`.
pub type EmbeddingBytes = [u8; 1024 * std::mem::size_of::<f32>()];
33
34/// Wrapped embedding representation.
35///
36/// See [module-level documentation](crate::embedding) for more details.
37#[derive(Debug, Clone, PartialEq)]
38pub struct Embedding {
39    inner: EmbeddingRaw,
40    norm: f32,
41}
42
43// Cosine similarity calculation
44
45impl Embedding {
46    /// Calculate cosine similarity between two embeddings.
47    #[must_use]
48    pub fn cosine_similarity(&self, other: &Self) -> f32 {
49        let dot_product = self
50            .iter()
51            .zip(other.iter())
52            .map(|(a, b)| a * b)
53            .sum::<f32>();
54        dot_product / (self.norm * other.norm)
55    }
56}
57
58impl Default for Embedding {
59    fn default() -> Self {
60        Self {
61            inner: [0.0; 1024],
62            norm: 0.0,
63        }
64    }
65}
66
// Conversion
68
69impl From<EmbeddingRaw> for Embedding {
70    /// Convert `[f32; 1024]` to `Embedding`.
71    fn from(inner: EmbeddingRaw) -> Self {
72        let norm = inner.iter().map(|a| a * a).sum::<f32>().sqrt();
73        Self { inner, norm }
74    }
75}
76
77impl From<EmbeddingBytes> for Embedding {
78    /// Convert 1024 * 4 bytes to `Embedding` (little-endian).
79    fn from(bytes: EmbeddingBytes) -> Self {
80        let mut embedding = [0.0; 1024];
81        bytes.chunks_exact(4).enumerate().for_each(|(i, chunk)| {
82            let f = f32::from_le_bytes(chunk.try_into().unwrap()); // Safe to unwrap, as we know the length is 4
83            embedding[i] = f;
84        });
85        Self::from(embedding)
86    }
87}
88
89impl From<Embedding> for EmbeddingBytes {
90    /// Convert `Embedding` to 1024 * 4 bytes (little-endian).
91    fn from(embedding: Embedding) -> Self {
92        let mut bytes = [0; 1024 * 4];
93        bytes
94            .chunks_exact_mut(4)
95            .enumerate()
96            .for_each(|(i, chunk)| {
97                let f = embedding[i];
98                chunk.copy_from_slice(&f.to_le_bytes());
99            });
100        bytes
101    }
102}
103
104impl TryFrom<&[f32]> for Embedding {
105    type Error = SenseError;
106
107    /// Convert `&[f32]` to `Embedding`.
108    ///
109    /// # Errors
110    ///
111    /// Returns [`DimensionMismatch`](SenseError::DimensionMismatch) if the length of the input slice is not 1024.
112    fn try_from(value: &[f32]) -> Result<Self, Self::Error> {
113        let embedding: EmbeddingRaw = value.try_into()?;
114        Ok(Self::from(embedding))
115    }
116}
117
118impl TryFrom<&[u8]> for Embedding {
119    type Error = SenseError;
120
121    /// Convert `&[u8]` to `Embedding`.
122    ///
123    /// # Errors
124    ///
125    /// Returns [`DimensionMismatch`](SenseError::DimensionMismatch) if the length of the input slice is not 1024 * 4.
126    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
127        let bytes: EmbeddingBytes = value.try_into()?;
128        Ok(Self::from(bytes))
129    }
130}
131
132impl TryFrom<Vec<f32>> for Embedding {
133    type Error = SenseError;
134
135    /// Convert `Vec<f32>` to `Embedding`.
136    ///
137    /// # Errors
138    ///
139    /// Returns [`DimensionMismatch`](SenseError::DimensionMismatch) if the length of the input vector is not 1024.
140    fn try_from(value: Vec<f32>) -> Result<Self, Self::Error> {
141        let embedding: EmbeddingRaw = value.try_into()?;
142        Ok(Self::from(embedding))
143    }
144}
145
146impl TryFrom<Vec<u8>> for Embedding {
147    type Error = SenseError;
148
149    /// Convert `Vec<u8>` to `Embedding`.
150    ///
151    /// # Errors
152    ///
153    /// Returns [`DimensionMismatch`](SenseError::DimensionMismatch) if the length of the input vector is not 1024 * 4.
154    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
155        let bytes: EmbeddingBytes = value.try_into()?;
156        Ok(Self::from(bytes))
157    }
158}
159
160// Implement `Deref` for `Embedding`
161
162impl Deref for Embedding {
163    type Target = EmbeddingRaw;
164
165    fn deref(&self) -> &Self::Target {
166        &self.inner
167    }
168}
169
170// Should not mutate the inner representation, since `norm` is cached based on it
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample component value; its bit pattern is 0x3F91EB85.
    const EMBEDDING_FLOAT: f32 = 1.14;
    /// The same value spelled out as little-endian bytes.
    const EMBEDDING_CHUNK: [u8; 4] = [0x85, 0xEB, 0x91, 0x3F];

    /// Decoding a buffer made of repeated sample chunks must give the sample
    /// float in every component.
    #[test]
    #[allow(clippy::float_cmp, reason = "They should be equal exactly")]
    fn embedding_from_bytes() {
        let mut bytes: EmbeddingBytes = [0; 1024 * 4];
        for chunk in bytes.chunks_exact_mut(4) {
            chunk.copy_from_slice(&EMBEDDING_CHUNK);
        }

        let embedding = Embedding::from(bytes);
        for &component in embedding.iter() {
            assert_eq!(component, EMBEDDING_FLOAT);
        }
    }

    /// Encoding an embedding filled with the sample float must give the
    /// sample chunk for every 4-byte group.
    #[test]
    fn bytes_from_embedding() {
        let raw: EmbeddingRaw = [EMBEDDING_FLOAT; 1024];
        let bytes = EmbeddingBytes::from(Embedding::from(raw));

        for chunk in bytes.chunks_exact(4) {
            assert_eq!(chunk, EMBEDDING_CHUNK);
        }
    }

    /// An embedding compared with itself has a cosine similarity of 1,
    /// up to floating-point rounding.
    #[test]
    fn similar_to_self() {
        let embedding = Embedding::from([EMBEDDING_FLOAT; 1024]);
        let similarity = embedding.cosine_similarity(&embedding);
        // Approximate equality
        assert!((similarity - 1.0).abs() <= f32::EPSILON);
    }
}