Skip to main content

zeph_common/
math.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Mathematical utilities for vector operations.
5//!
6//! This module provides general-purpose vector math as well as the
7//! [`EmbeddingVector<State>`] typestate wrapper that encodes L2-normalization at the
8//! type level. Use [`EmbeddingVector::<Normalized>`] as the required parameter type on
9//! functions that feed vectors directly into Qdrant cosine-distance searches.
10
11use std::marker::PhantomData;
12
13// ── Typestate markers ────────────────────────────────────────────────────────
14
/// Typestate marker indicating that an [`EmbeddingVector`] has been L2-normalized
/// to unit length.
///
/// This is a zero-sized type meant to be used only as a type parameter (through
/// `PhantomData`); it is not meant to be instantiated. Its private field prevents code
/// outside this module from constructing it. What is actually guarded is the tagged
/// vector: values of type `EmbeddingVector<Normalized>` can only be obtained from
/// [`EmbeddingVector::normalize`] or the trust-caller constructor
/// [`EmbeddingVector::<Normalized>::new_normalized`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Normalized(());

/// Typestate marker indicating that an [`EmbeddingVector`] has **not** been
/// normalized yet (raw output from a model or loaded from storage).
///
/// Like [`Normalized`], this is a zero-sized type with a private field: it is only
/// used as a type parameter and cannot be constructed outside this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Unnormalized(());
28
29// ── EmbeddingVector ──────────────────────────────────────────────────────────
30
/// An embedding vector tagged with a normalization-state marker.
///
/// The type parameter records whether the vector is L2-normalized:
///
/// - `EmbeddingVector<Unnormalized>` — raw model output; must be normalized before
///   passing to cosine-distance Qdrant searches.
/// - `EmbeddingVector<Normalized>` — unit-length vector, safe to pass directly to
///   Qdrant gRPC cosine queries.
///
/// Requiring [`Normalized`] as the parameter type at the Qdrant search boundary turns a
/// forgotten normalization step into a compile-time error rather than silent
/// near-zero similarity scores (see bugs #3421, #3382, #3420, #3422).
///
/// Only the normalization state is tracked by the type. The number of dimensions is
/// **not** encoded, so dimension mismatches still have to be checked at runtime
/// (e.g. via [`len`](Self::len)).
///
/// The fields are private, so code outside this module can only create values through
/// the constructors defined here.
///
/// # Construction
///
/// ```
/// use zeph_common::math::{EmbeddingVector, Normalized, Unnormalized};
///
/// // Wrap a raw model vector and normalize it.
/// let raw = EmbeddingVector::<Unnormalized>::new(vec![3.0_f32, 4.0]);
/// let normalized = raw.normalize();
/// let slice = normalized.as_slice();
/// // A normalized vector has unit L2 length.
/// let norm: f32 = slice.iter().map(|x| x * x).sum::<f32>().sqrt();
/// assert!((norm - 1.0).abs() < 1e-6);
///
/// // Trust-caller constructor for models that always return unit vectors.
/// let trusted = EmbeddingVector::<Normalized>::new_normalized(vec![0.6_f32, 0.8]);
/// assert_eq!(trusted.as_slice(), &[0.6_f32, 0.8]);
/// ```
#[derive(Debug, Clone)]
pub struct EmbeddingVector<State> {
    /// The vector components.
    inner: Vec<f32>,
    /// Zero-sized tag carrying the normalization state; no runtime cost.
    _state: PhantomData<State>,
}
66
67impl EmbeddingVector<Unnormalized> {
68    /// Wrap a raw embedding vector from a model or storage.
69    ///
70    /// The returned vector is tagged `Unnormalized`. Call [`normalize`](Self::normalize)
71    /// before passing it to functions that require [`Normalized`].
72    ///
73    /// # Examples
74    ///
75    /// ```
76    /// use zeph_common::math::{EmbeddingVector, Unnormalized};
77    ///
78    /// let v = EmbeddingVector::<Unnormalized>::new(vec![1.0_f32, 0.0]);
79    /// assert_eq!(v.as_slice(), &[1.0_f32, 0.0]);
80    /// ```
81    #[must_use]
82    pub fn new(inner: Vec<f32>) -> Self {
83        Self {
84            inner,
85            _state: PhantomData,
86        }
87    }
88
89    /// L2-normalize this vector and return an [`EmbeddingVector<Normalized>`].
90    ///
91    /// If the vector is a zero vector (L2 norm is zero), all elements are set to zero
92    /// to avoid division by zero; the result is technically invalid for cosine search
93    /// but is safe and consistent with the behavior of [`cosine_similarity`].
94    ///
95    /// # Examples
96    ///
97    /// ```
98    /// use zeph_common::math::{EmbeddingVector, Unnormalized};
99    ///
100    /// let raw = EmbeddingVector::<Unnormalized>::new(vec![3.0_f32, 4.0]);
101    /// let norm = raw.normalize();
102    /// let sum_sq: f32 = norm.as_slice().iter().map(|x| x * x).sum();
103    /// assert!((sum_sq - 1.0).abs() < 1e-6, "must be unit length");
104    /// ```
105    #[must_use]
106    pub fn normalize(self) -> EmbeddingVector<Normalized> {
107        let norm: f32 = self.inner.iter().map(|x| x * x).sum::<f32>().sqrt();
108        let normalized = if norm < f32::EPSILON {
109            self.inner
110        } else {
111            self.inner.into_iter().map(|x| x / norm).collect()
112        };
113        EmbeddingVector {
114            inner: normalized,
115            _state: PhantomData,
116        }
117    }
118}
119
120impl EmbeddingVector<Normalized> {
121    /// Construct a normalized embedding vector, trusting the caller that `inner` is
122    /// already L2-unit-length.
123    ///
124    /// Use this constructor only when the source guarantees unit-length output (e.g., a
125    /// model that always normalizes, or a vector loaded from a store known to hold
126    /// normalized data). Incorrect use does **not** cause UB but will produce wrong
127    /// cosine scores.
128    ///
129    /// # Examples
130    ///
131    /// ```
132    /// use zeph_common::math::{EmbeddingVector, Normalized};
133    ///
134    /// // Suppose the model always returns unit vectors.
135    /// let v = EmbeddingVector::<Normalized>::new_normalized(vec![0.6_f32, 0.8]);
136    /// assert_eq!(v.as_slice(), &[0.6_f32, 0.8]);
137    /// ```
138    #[must_use]
139    pub fn new_normalized(inner: Vec<f32>) -> Self {
140        Self {
141            inner,
142            _state: PhantomData,
143        }
144    }
145}
146
147impl<State> EmbeddingVector<State> {
148    /// Return a borrowed slice of the vector elements.
149    ///
150    /// # Examples
151    ///
152    /// ```
153    /// use zeph_common::math::{EmbeddingVector, Unnormalized};
154    ///
155    /// let v = EmbeddingVector::<Unnormalized>::new(vec![1.0_f32, 2.0]);
156    /// assert_eq!(v.as_slice(), &[1.0_f32, 2.0]);
157    /// ```
158    #[must_use]
159    pub fn as_slice(&self) -> &[f32] {
160        &self.inner
161    }
162
163    /// Consume the wrapper and return the underlying `Vec<f32>`.
164    ///
165    /// # Examples
166    ///
167    /// ```
168    /// use zeph_common::math::{EmbeddingVector, Unnormalized};
169    ///
170    /// let v = EmbeddingVector::<Unnormalized>::new(vec![1.0_f32, 2.0]);
171    /// assert_eq!(v.into_inner(), vec![1.0_f32, 2.0]);
172    /// ```
173    #[must_use]
174    pub fn into_inner(self) -> Vec<f32> {
175        self.inner
176    }
177
178    /// Return the number of dimensions in this vector.
179    ///
180    /// # Examples
181    ///
182    /// ```
183    /// use zeph_common::math::{EmbeddingVector, Unnormalized};
184    ///
185    /// let v = EmbeddingVector::<Unnormalized>::new(vec![1.0_f32, 2.0, 3.0]);
186    /// assert_eq!(v.len(), 3);
187    /// ```
188    #[must_use]
189    pub fn len(&self) -> usize {
190        self.inner.len()
191    }
192
193    /// Returns `true` if the vector has no elements.
194    ///
195    /// # Examples
196    ///
197    /// ```
198    /// use zeph_common::math::{EmbeddingVector, Unnormalized};
199    ///
200    /// let v = EmbeddingVector::<Unnormalized>::new(vec![]);
201    /// assert!(v.is_empty());
202    /// ```
203    #[must_use]
204    pub fn is_empty(&self) -> bool {
205        self.inner.is_empty()
206    }
207}
208
209impl From<Vec<f32>> for EmbeddingVector<Unnormalized> {
210    fn from(v: Vec<f32>) -> Self {
211        Self::new(v)
212    }
213}
214
/// Compute cosine similarity between two equal-length f32 vectors.
///
/// The result is `dot(a, b) / (‖a‖ · ‖b‖)`, clamped to `[-1.0, 1.0]` to absorb
/// floating-point rounding. Cosine similarity is scale-invariant, so vectors of any
/// non-zero magnitude (including very small ones) are compared by direction only.
///
/// Returns `0.0` if the vectors have different lengths, are empty, or if either
/// vector is a zero vector. More precisely, `0.0` is returned when the product of the
/// two norms is zero or subnormal. A length mismatch yields `0.0` rather than a panic.
/// If any component is NaN, the result is NaN.
///
/// The dot product and both squared norms are accumulated in a single pass.
#[inline]
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let mut dot = 0.0_f32;
    let mut norm_a = 0.0_f32;
    let mut norm_b = 0.0_f32;

    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let denom = norm_a.sqrt() * norm_b.sqrt();
    // `f32::EPSILON` (~1.2e-7) is a relative precision, not an absolute "near zero"
    // cutoff. Comparing the norm product against it made e.g. `[1e-4, 0]` vs. itself
    // score 0.0 instead of 1.0. Only bail out when the denominator is zero or
    // subnormal, where the division would be undefined or imprecise. A NaN
    // denominator fails this test and propagates through the division.
    if denom < f32::MIN_POSITIVE {
        return 0.0;
    }

    (dot / denom).clamp(-1.0, 1.0)
}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for approximate float comparisons.
    const TOL: f32 = 1e-6;

    /// Returns `true` if `actual` lies strictly within [`TOL`] of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Returns `true` if `value` is zero to within `f32::EPSILON`.
    fn is_zero(value: f32) -> bool {
        value.abs() <= f32::EPSILON
    }

    /// Sum of squared components (the squared L2 norm).
    fn sum_of_squares(xs: &[f32]) -> f32 {
        xs.iter().map(|x| x * x).sum()
    }

    // ── cosine_similarity tests ──────────────────────────────────────────────

    #[test]
    fn identical_vectors() {
        let v = [1.0_f32, 2.0, 3.0];
        assert!(close(cosine_similarity(&v, &v), 1.0));
    }

    #[test]
    fn orthogonal_vectors() {
        assert!(close(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0));
    }

    #[test]
    fn opposite_vectors() {
        assert!(close(cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]), -1.0));
    }

    #[test]
    fn zero_vector() {
        assert!(is_zero(cosine_similarity(&[0.0, 0.0], &[1.0, 0.0])));
    }

    #[test]
    fn different_lengths() {
        assert!(is_zero(cosine_similarity(&[1.0], &[1.0, 0.0])));
    }

    #[test]
    fn empty_vectors() {
        assert!(is_zero(cosine_similarity(&[], &[])));
    }

    #[test]
    fn parallel_vectors() {
        assert!(close(cosine_similarity(&[2.0, 0.0], &[5.0, 0.0]), 1.0));
    }

    #[test]
    fn normalized_vectors() {
        let s = 1.0_f32 / 2.0_f32.sqrt();
        let unit = [s, s];
        assert!(close(cosine_similarity(&unit, &unit), 1.0));
    }

    // ── EmbeddingVector tests ────────────────────────────────────────────────

    #[test]
    fn embedding_vector_normalize_produces_unit_vector() {
        let normed = EmbeddingVector::<Unnormalized>::new(vec![3.0_f32, 4.0]).normalize();
        assert!(close(sum_of_squares(normed.as_slice()), 1.0));
    }

    #[test]
    fn embedding_vector_normalize_zero_vector_is_safe() {
        let normed = EmbeddingVector::<Unnormalized>::new(vec![0.0_f32, 0.0]).normalize();
        assert_eq!(normed.as_slice(), &[0.0_f32, 0.0]);
    }

    #[test]
    fn embedding_vector_into_inner_roundtrip() {
        let data = vec![1.0_f32, 2.0, 3.0];
        let wrapped = EmbeddingVector::<Unnormalized>::new(data.clone());
        assert_eq!(wrapped.into_inner(), data);
    }

    #[test]
    fn embedding_vector_len_and_is_empty() {
        let pair = EmbeddingVector::<Unnormalized>::new(vec![1.0_f32, 2.0]);
        assert_eq!(pair.len(), 2);
        assert!(!pair.is_empty());

        let nothing = EmbeddingVector::<Unnormalized>::new(Vec::new());
        assert!(nothing.is_empty());
    }

    #[test]
    fn embedding_vector_new_normalized_trust_caller() {
        let trusted = EmbeddingVector::<Normalized>::new_normalized(vec![0.6_f32, 0.8]);
        assert_eq!(trusted.as_slice(), &[0.6_f32, 0.8]);
    }

    #[test]
    fn embedding_vector_from_vec() {
        let converted = EmbeddingVector::<Unnormalized>::from(vec![1.0_f32, 2.0]);
        assert_eq!(converted.as_slice(), &[1.0_f32, 2.0]);
    }
}