zeph_common/math.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Mathematical utilities for vector operations.
5//!
6//! This module provides general-purpose vector math as well as the
7//! [`EmbeddingVector<State>`] typestate wrapper that encodes L2-normalization at the
8//! type level. Use [`EmbeddingVector::<Normalized>`] as the required parameter type on
9//! functions that feed vectors directly into Qdrant cosine-distance searches.
10
11use std::marker::PhantomData;
12
13// ── Typestate markers ────────────────────────────────────────────────────────
14
/// Typestate marker for an [`EmbeddingVector`] that has been L2-normalized to unit
/// length.
///
/// The private `()` field prevents code outside this module from constructing a
/// value of this type. The marker is intended to be used purely as a type
/// parameter: an `EmbeddingVector<Normalized>` is obtained from
/// [`EmbeddingVector::normalize`] or from the trust-caller constructor
/// [`EmbeddingVector::<Normalized>::new_normalized`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Normalized(());
23
/// Typestate marker for an [`EmbeddingVector`] that has **not** been normalized yet
/// (for example raw output from a model, or a vector loaded from storage).
///
/// The private `()` field keeps the marker from being constructed outside this
/// module; it is intended to be used as a type parameter only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Unnormalized(());
28
29// ── EmbeddingVector ──────────────────────────────────────────────────────────
30
/// An embedding vector tagged with a normalization-state marker.
///
/// The type parameter encodes whether the vector is L2-normalized:
///
/// - `EmbeddingVector<Unnormalized>` — raw model output; must be normalized before
///   passing to cosine-distance Qdrant searches.
/// - `EmbeddingVector<Normalized>` — unit-length vector, safe to pass directly to
///   Qdrant gRPC cosine queries.
///
/// Using [`Normalized`] as a required parameter type at the Qdrant search boundary
/// turns normalization mismatches into compile-time errors rather than silent
/// near-zero similarity scores (see bugs #3421, #3382, #3420, #3422).
///
/// The vector *length* (its dimension) is not part of the type, so a dimension
/// mismatch is not caught by the compiler.
///
/// # Construction
///
/// ```
/// use zeph_common::math::{EmbeddingVector, Normalized, Unnormalized};
///
/// // Wrap a raw model vector and normalize it.
/// let raw = EmbeddingVector::<Unnormalized>::new(vec![3.0_f32, 4.0]);
/// let normalized = raw.normalize();
/// let slice = normalized.as_slice();
/// // A normalized vector has unit L2 length.
/// let norm: f32 = slice.iter().map(|x| x * x).sum::<f32>().sqrt();
/// assert!((norm - 1.0).abs() < 1e-6);
///
/// // Trust-caller constructor for models that always return unit vectors.
/// let trusted = EmbeddingVector::<Normalized>::new_normalized(vec![0.6_f32, 0.8]);
/// assert_eq!(trusted.as_slice(), &[0.6_f32, 0.8]);
/// ```
#[derive(Debug, Clone)]
pub struct EmbeddingVector<State> {
    // The vector components, one `f32` per dimension.
    inner: Vec<f32>,
    // Zero-sized marker that ties this value to its normalization state.
    _state: PhantomData<State>,
}
66
67impl EmbeddingVector<Unnormalized> {
68 /// Wrap a raw embedding vector from a model or storage.
69 ///
70 /// The returned vector is tagged `Unnormalized`. Call [`normalize`](Self::normalize)
71 /// before passing it to functions that require [`Normalized`].
72 ///
73 /// # Examples
74 ///
75 /// ```
76 /// use zeph_common::math::{EmbeddingVector, Unnormalized};
77 ///
78 /// let v = EmbeddingVector::<Unnormalized>::new(vec![1.0_f32, 0.0]);
79 /// assert_eq!(v.as_slice(), &[1.0_f32, 0.0]);
80 /// ```
81 #[must_use]
82 pub fn new(inner: Vec<f32>) -> Self {
83 Self {
84 inner,
85 _state: PhantomData,
86 }
87 }
88
89 /// L2-normalize this vector and return an [`EmbeddingVector<Normalized>`].
90 ///
91 /// If the vector is a zero vector (L2 norm is zero), all elements are set to zero
92 /// to avoid division by zero; the result is technically invalid for cosine search
93 /// but is safe and consistent with the behavior of [`cosine_similarity`].
94 ///
95 /// # Examples
96 ///
97 /// ```
98 /// use zeph_common::math::{EmbeddingVector, Unnormalized};
99 ///
100 /// let raw = EmbeddingVector::<Unnormalized>::new(vec![3.0_f32, 4.0]);
101 /// let norm = raw.normalize();
102 /// let sum_sq: f32 = norm.as_slice().iter().map(|x| x * x).sum();
103 /// assert!((sum_sq - 1.0).abs() < 1e-6, "must be unit length");
104 /// ```
105 #[must_use]
106 pub fn normalize(self) -> EmbeddingVector<Normalized> {
107 let norm: f32 = self.inner.iter().map(|x| x * x).sum::<f32>().sqrt();
108 let normalized = if norm < f32::EPSILON {
109 self.inner
110 } else {
111 self.inner.into_iter().map(|x| x / norm).collect()
112 };
113 EmbeddingVector {
114 inner: normalized,
115 _state: PhantomData,
116 }
117 }
118}
119
120impl EmbeddingVector<Normalized> {
121 /// Construct a normalized embedding vector, trusting the caller that `inner` is
122 /// already L2-unit-length.
123 ///
124 /// Use this constructor only when the source guarantees unit-length output (e.g., a
125 /// model that always normalizes, or a vector loaded from a store known to hold
126 /// normalized data). Incorrect use does **not** cause UB but will produce wrong
127 /// cosine scores.
128 ///
129 /// # Examples
130 ///
131 /// ```
132 /// use zeph_common::math::{EmbeddingVector, Normalized};
133 ///
134 /// // Suppose the model always returns unit vectors.
135 /// let v = EmbeddingVector::<Normalized>::new_normalized(vec![0.6_f32, 0.8]);
136 /// assert_eq!(v.as_slice(), &[0.6_f32, 0.8]);
137 /// ```
138 #[must_use]
139 pub fn new_normalized(inner: Vec<f32>) -> Self {
140 Self {
141 inner,
142 _state: PhantomData,
143 }
144 }
145}
146
147impl<State> EmbeddingVector<State> {
148 /// Return a borrowed slice of the vector elements.
149 ///
150 /// # Examples
151 ///
152 /// ```
153 /// use zeph_common::math::{EmbeddingVector, Unnormalized};
154 ///
155 /// let v = EmbeddingVector::<Unnormalized>::new(vec![1.0_f32, 2.0]);
156 /// assert_eq!(v.as_slice(), &[1.0_f32, 2.0]);
157 /// ```
158 #[must_use]
159 pub fn as_slice(&self) -> &[f32] {
160 &self.inner
161 }
162
163 /// Consume the wrapper and return the underlying `Vec<f32>`.
164 ///
165 /// # Examples
166 ///
167 /// ```
168 /// use zeph_common::math::{EmbeddingVector, Unnormalized};
169 ///
170 /// let v = EmbeddingVector::<Unnormalized>::new(vec![1.0_f32, 2.0]);
171 /// assert_eq!(v.into_inner(), vec![1.0_f32, 2.0]);
172 /// ```
173 #[must_use]
174 pub fn into_inner(self) -> Vec<f32> {
175 self.inner
176 }
177
178 /// Return the number of dimensions in this vector.
179 ///
180 /// # Examples
181 ///
182 /// ```
183 /// use zeph_common::math::{EmbeddingVector, Unnormalized};
184 ///
185 /// let v = EmbeddingVector::<Unnormalized>::new(vec![1.0_f32, 2.0, 3.0]);
186 /// assert_eq!(v.len(), 3);
187 /// ```
188 #[must_use]
189 pub fn len(&self) -> usize {
190 self.inner.len()
191 }
192
193 /// Returns `true` if the vector has no elements.
194 ///
195 /// # Examples
196 ///
197 /// ```
198 /// use zeph_common::math::{EmbeddingVector, Unnormalized};
199 ///
200 /// let v = EmbeddingVector::<Unnormalized>::new(vec![]);
201 /// assert!(v.is_empty());
202 /// ```
203 #[must_use]
204 pub fn is_empty(&self) -> bool {
205 self.inner.is_empty()
206 }
207}
208
209impl From<Vec<f32>> for EmbeddingVector<Unnormalized> {
210 fn from(v: Vec<f32>) -> Self {
211 Self::new(v)
212 }
213}
214
/// Compute cosine similarity between two equal-length f32 vectors.
///
/// The result is clamped to `[-1.0, 1.0]` to absorb floating-point rounding.
///
/// Returns `0.0` for every degenerate input instead of panicking or returning
/// NaN:
///
/// - the vectors have different lengths,
/// - the vectors are empty,
/// - either vector is a zero vector (the product of the two norms is below
///   `f32::EPSILON`),
/// - the quotient is NaN, which happens when a component is NaN or infinite, or
///   when the sums of squares overflow `f32`.
///
/// Uses a single-pass loop for efficiency.
#[inline]
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let mut dot = 0.0_f32;
    let mut norm_a = 0.0_f32;
    let mut norm_b = 0.0_f32;

    for (x, y) in a.iter().zip(b) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom < f32::EPSILON {
        return 0.0;
    }

    let cosine = dot / denom;
    if cosine.is_nan() {
        // Non-finite input or overflow: a NaN denominator is not caught by the
        // `<` comparison above, so it is handled here. Callers always get a
        // real number they can sort and threshold.
        return 0.0;
    }

    cosine.clamp(-1.0, 1.0)
}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for results that should match an exact value.
    const TOLERANCE: f32 = 1e-6;

    /// Assert that `actual` lies strictly within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {expected}, got {actual}"
        );
    }

    /// Assert that `actual` is zero up to `f32::EPSILON`.
    fn assert_zero(actual: f32) {
        assert!(actual.abs() <= f32::EPSILON, "expected 0.0, got {actual}");
    }

    /// Sum of squared components, i.e. the squared L2 norm.
    fn squared_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum()
    }

    // ── cosine_similarity tests ──────────────────────────────────────────────

    #[test]
    fn identical_vectors() {
        let v = [1.0_f32, 2.0, 3.0];
        assert_close(cosine_similarity(&v, &v), 1.0);
    }

    #[test]
    fn orthogonal_vectors() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        assert_close(cosine_similarity(&x_axis, &y_axis), 0.0);
    }

    #[test]
    fn opposite_vectors() {
        let forward = [1.0_f32, 0.0];
        let backward = [-1.0_f32, 0.0];
        assert_close(cosine_similarity(&forward, &backward), -1.0);
    }

    #[test]
    fn zero_vector() {
        let zero = [0.0_f32, 0.0];
        let unit = [1.0_f32, 0.0];
        assert_zero(cosine_similarity(&zero, &unit));
    }

    #[test]
    fn different_lengths() {
        let short = [1.0_f32];
        let long = [1.0_f32, 0.0];
        assert_zero(cosine_similarity(&short, &long));
    }

    #[test]
    fn empty_vectors() {
        assert_zero(cosine_similarity(&[], &[]));
    }

    #[test]
    fn parallel_vectors() {
        let shorter = [2.0_f32, 0.0];
        let longer = [5.0_f32, 0.0];
        assert_close(cosine_similarity(&shorter, &longer), 1.0);
    }

    #[test]
    fn normalized_vectors() {
        let component = 1.0_f32 / 2.0_f32.sqrt();
        let lhs = [component, component];
        let rhs = [component, component];
        assert_close(cosine_similarity(&lhs, &rhs), 1.0);
    }

    // ── EmbeddingVector tests ────────────────────────────────────────────────

    #[test]
    fn embedding_vector_normalize_produces_unit_vector() {
        let unit = EmbeddingVector::<Unnormalized>::new(vec![3.0_f32, 4.0]).normalize();
        assert_close(squared_norm(unit.as_slice()), 1.0);
    }

    #[test]
    fn embedding_vector_normalize_zero_vector_is_safe() {
        let zeroed = EmbeddingVector::<Unnormalized>::new(vec![0.0_f32, 0.0]).normalize();
        assert_eq!(zeroed.as_slice(), &[0.0_f32, 0.0]);
    }

    #[test]
    fn embedding_vector_into_inner_roundtrip() {
        let expected = [1.0_f32, 2.0, 3.0];
        let wrapped = EmbeddingVector::<Unnormalized>::new(expected.to_vec());
        assert_eq!(wrapped.into_inner(), expected);
    }

    #[test]
    fn embedding_vector_len_and_is_empty() {
        let two = EmbeddingVector::<Unnormalized>::new(vec![1.0_f32, 2.0]);
        assert_eq!(two.len(), 2);
        assert!(!two.is_empty());

        let none = EmbeddingVector::<Unnormalized>::new(Vec::new());
        assert!(none.is_empty());
    }

    #[test]
    fn embedding_vector_new_normalized_trust_caller() {
        let trusted = EmbeddingVector::<Normalized>::new_normalized(vec![0.6_f32, 0.8]);
        assert_eq!(trusted.as_slice(), &[0.6_f32, 0.8]);
    }

    #[test]
    fn embedding_vector_from_vec() {
        let converted = EmbeddingVector::<Unnormalized>::from(vec![1.0_f32, 2.0]);
        assert_eq!(converted.as_slice(), &[1.0_f32, 2.0]);
    }
}