Skip to main content

terminals_core/primitives/vector.rs

/// Cosine similarity between two vectors.
///
/// Returns a value clamped to `[-1.0, 1.0]`. If either vector has an L2 norm
/// not greater than `1e-10` (the same threshold [`normalize`] uses), the result
/// is `0.0`. A NaN component otherwise propagates to a NaN result.
///
/// Both slices are expected to have the same length; this is only checked with
/// `debug_assert_eq!`. In release builds the computation runs over the common
/// prefix (as [`euclidean_distance`] does) instead of panicking on an
/// out-of-bounds index.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, sq_a, sq_b), (&x, &y)| {
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    // Threshold each norm on its own rather than their product, so the result
    // stays scale-invariant: two parallel vectors of norm 1e-6 still score 1.0.
    if norm_a <= 1e-10 || norm_b <= 1e-10 {
        return 0.0;
    }
    // Floating-point rounding can push the ratio marginally past +/-1.
    // `clamp` leaves NaN unchanged.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
19
/// Rescale `v` in place to unit L2 norm.
///
/// If the norm is not greater than `1e-10` — which includes the all-zero
/// vector and any vector containing NaN — `v` is left unchanged, so zero
/// vectors remain zero.
pub fn normalize(v: &mut [f32]) {
    let squared_sum: f32 = v.iter().map(|x| x * x).sum();
    let norm = squared_sum.sqrt();
    if norm > 1e-10 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}
29
/// Straight-line (L2) distance between `a` and `b`.
///
/// The slices should have the same length; this is only checked by
/// `debug_assert_eq!`. In release builds, trailing elements of the longer
/// slice are ignored.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let delta = x - y;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
39
40/// Top-K nearest vectors by cosine similarity (descending).
41pub fn top_k(query: &[f32], candidates: &[&[f32]], k: usize) -> Vec<(usize, f32)> {
42    let mut scores: Vec<(usize, f32)> = candidates
43        .iter()
44        .enumerate()
45        .map(|(i, c)| (i, cosine_similarity(query, c)))
46        .collect();
47    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
48    scores.truncate(k);
49    scores
50}
51
/// Centroid (element-wise mean) of N vectors.
///
/// The output has the length of the first vector; an empty input yields an
/// empty vector. A later vector shorter than the first contributes zero for
/// its missing components, while a longer one causes an index-out-of-bounds
/// panic.
pub fn centroid(vectors: &[&[f32]]) -> Vec<f32> {
    let dim = match vectors.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };
    let count = vectors.len() as f32;
    let mut sums = vec![0.0f32; dim];
    for vector in vectors.iter() {
        vector
            .iter()
            .enumerate()
            .for_each(|(idx, &value)| sums[idx] += value);
    }
    sums.into_iter().map(|total| total / count).collect()
}
70
/// Replace every NaN element of `v` with `0.0`, in place.
///
/// All other values, including infinities, are left untouched.
pub fn sanitize_nan(v: &mut [f32]) {
    v.iter_mut()
        .filter(|x| x.is_nan())
        .for_each(|x| *x = 0.0);
}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow a list of owned vectors in the `&[&[f32]]` shape the API takes.
    fn as_refs(vecs: &[Vec<f32>]) -> Vec<&[f32]> {
        vecs.iter().map(|v| v.as_slice()).collect()
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![1.0f32, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![0.0f32, 1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = [1.0f32, 2.0];
        let b = [-1.0f32, -2.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = [0.0f32, 0.0, 0.0];
        let b = [1.0f32, 2.0, 3.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_cosine_similarity_384_dim() {
        let mut a = vec![0.0f32; 384];
        let mut b = vec![0.0f32; 384];
        a[0] = 1.0;
        b[0] = 0.5;
        b[1] = 0.866;
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 0.5).abs() < 1e-3);
    }

    #[test]
    fn test_normalize() {
        let mut v = vec![3.0f32, 4.0];
        normalize(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_zero_vector() {
        let mut v = vec![0.0f32; 384];
        normalize(&mut v);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_normalize_empty() {
        let mut v: Vec<f32> = Vec::new();
        normalize(&mut v);
        assert!(v.is_empty());
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0f32, 0.0];
        let b = vec![3.0f32, 4.0];
        let d = euclidean_distance(&a, &b);
        assert!((d - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance_identical() {
        let a = [1.5f32, -2.0, 3.25];
        assert_eq!(euclidean_distance(&a, &a), 0.0);
    }

    #[test]
    fn test_top_k() {
        let query = vec![1.0f32, 0.0, 0.0];
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.0, 1.0, 0.0], // orthogonal
            vec![1.0, 0.0, 0.0], // identical
            vec![0.5, 0.5, 0.0], // partial
        ];
        let refs = as_refs(&candidates);
        let results = top_k(&query, &refs, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 1); // identical first
    }

    #[test]
    fn test_top_k_k_exceeds_len_and_zero() {
        let query = vec![1.0f32, 0.0, 0.0];
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.0, 1.0, 0.0], // orthogonal
            vec![1.0, 0.0, 0.0], // identical
            vec![0.5, 0.5, 0.0], // partial
        ];
        let refs = as_refs(&candidates);

        // Asking for more than available returns everything, best first.
        let all = top_k(&query, &refs, 10);
        let order: Vec<usize> = all.iter().map(|&(i, _)| i).collect();
        assert_eq!(order, vec![1, 2, 0]);

        assert!(top_k(&query, &refs, 0).is_empty());
    }

    #[test]
    fn test_centroid() {
        let vecs: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let refs = as_refs(&vecs);
        let c = centroid(&refs);
        assert!((c[0] - 0.5).abs() < 1e-6);
        assert!((c[1] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_centroid_empty() {
        let none: Vec<&[f32]> = Vec::new();
        assert!(centroid(&none).is_empty());
    }

    #[test]
    fn test_sanitize_nan() {
        let mut v = vec![1.0f32, f32::NAN, f32::INFINITY, -2.5, f32::NAN];
        sanitize_nan(&mut v);
        assert_eq!(v, vec![1.0f32, 0.0, f32::INFINITY, -2.5, 0.0]);
    }
}