// vectx_core/vector.rs

1use serde::{Deserialize, Serialize};
2use std::ops::{Add, Mul, Sub};
3
4/// A vector of floating point numbers
5#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
6pub struct Vector {
7    data: Vec<f32>,
8}
9
10impl Vector {
11    #[inline]
12    #[must_use]
13    pub fn new(data: Vec<f32>) -> Self {
14        Self { data }
15    }
16
17    #[inline]
18    #[must_use]
19    pub fn from_slice(data: &[f32]) -> Self {
20        Self {
21            data: data.to_vec(),
22        }
23    }
24
25    #[inline]
26    #[must_use]
27    pub fn dim(&self) -> usize {
28        self.data.len()
29    }
30
31    #[inline]
32    #[must_use]
33    pub fn is_empty(&self) -> bool {
34        self.data.is_empty()
35    }
36
37    #[inline]
38    #[must_use]
39    pub fn as_slice(&self) -> &[f32] {
40        &self.data
41    }
42
43    #[inline]
44    pub fn as_mut_slice(&mut self) -> &mut [f32] {
45        &mut self.data
46    }
47
48    /// Compute cosine similarity with another vector
49    /// Uses SIMD-optimized operations for both dot product and norms
50    #[inline]
51    pub fn cosine_similarity(&self, other: &Vector) -> f32 {
52        if self.dim() != other.dim() {
53            return 0.0;
54        }
55
56        let dot_product = crate::simd::dot_product_simd(&self.data, &other.data);
57
58        // Use SIMD-optimized norm calculation
59        let norm_a = crate::simd::norm_simd(&self.data);
60        let norm_b = crate::simd::norm_simd(&other.data);
61
62        if norm_a == 0.0 || norm_b == 0.0 {
63            return 0.0;
64        }
65
66        dot_product / (norm_a * norm_b)
67    }
68
69    /// Compute L2 (Euclidean) distance
70    #[inline]
71    pub fn l2_distance(&self, other: &Vector) -> f32 {
72        if self.dim() != other.dim() {
73            return f32::INFINITY;
74        }
75
76        crate::simd::l2_distance_simd(&self.data, &other.data)
77    }
78
79    /// Normalize the vector to unit length
80    /// Uses SIMD-optimized norm calculation
81    #[inline]
82    pub fn normalize(&mut self) {
83        let norm = crate::simd::norm_simd(&self.data);
84        if norm > f32::EPSILON {
85            let inv_norm = 1.0 / norm;
86            for x in &mut self.data {
87                *x *= inv_norm;
88            }
89        }
90    }
91
92    /// Get normalized copy
93    #[inline]
94    #[must_use]
95    pub fn normalized(&self) -> Self {
96        let mut v = self.clone();
97        v.normalize();
98        v
99    }
100}
101
102impl Add for &Vector {
103    type Output = Vector;
104
105    fn add(self, other: &Vector) -> Vector {
106        assert_eq!(self.dim(), other.dim());
107        Vector::new(
108            self.data
109                .iter()
110                .zip(other.data.iter())
111                .map(|(a, b)| a + b)
112                .collect(),
113        )
114    }
115}
116
117impl Sub for &Vector {
118    type Output = Vector;
119
120    fn sub(self, other: &Vector) -> Vector {
121        assert_eq!(self.dim(), other.dim());
122        Vector::new(
123            self.data
124                .iter()
125                .zip(other.data.iter())
126                .map(|(a, b)| a - b)
127                .collect(),
128        )
129    }
130}
131
132impl Mul<f32> for &Vector {
133    type Output = Vector;
134
135    fn mul(self, scalar: f32) -> Vector {
136        Vector::new(self.data.iter().map(|x| x * scalar).collect())
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree to within 1e-6, reporting both on failure.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_cosine_similarity() {
        let v1 = Vector::new(vec![1.0, 0.0]);
        let v2 = Vector::new(vec![1.0, 0.0]);
        assert_close(v1.cosine_similarity(&v2), 1.0);

        let v3 = Vector::new(vec![1.0, 0.0]);
        let v4 = Vector::new(vec![0.0, 1.0]);
        assert_close(v3.cosine_similarity(&v4), 0.0);

        let v5 = Vector::new(vec![-1.0, 0.0]);
        assert_close(v1.cosine_similarity(&v5), -1.0);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        let v = Vector::new(vec![1.0, 2.0]);
        // Both a dimension mismatch and a zero vector map to 0.0 rather than NaN.
        assert_eq!(v.cosine_similarity(&Vector::new(vec![1.0, 2.0, 3.0])), 0.0);
        assert_eq!(v.cosine_similarity(&Vector::new(vec![0.0, 0.0])), 0.0);
    }

    #[test]
    fn test_l2_distance() {
        let v1 = Vector::new(vec![0.0, 0.0]);
        let v2 = Vector::new(vec![3.0, 4.0]);
        assert_close(v1.l2_distance(&v2), 5.0);

        // Mismatched dimensions are reported as infinitely far apart.
        assert_eq!(v1.l2_distance(&Vector::new(vec![1.0])), f32::INFINITY);
    }

    #[test]
    fn test_normalize() {
        let v = Vector::new(vec![3.0, 4.0]).normalized();
        assert_close(v.as_slice()[0], 0.6);
        assert_close(v.as_slice()[1], 0.8);

        // A zero vector has no direction and must be left untouched.
        let mut zero = Vector::new(vec![0.0, 0.0]);
        zero.normalize();
        assert_eq!(zero, Vector::new(vec![0.0, 0.0]));
    }

    #[test]
    fn test_arithmetic_operators() {
        let a = Vector::new(vec![1.0, 2.0]);
        let b = Vector::new(vec![3.0, 4.0]);
        assert_eq!(&a + &b, Vector::new(vec![4.0, 6.0]));
        assert_eq!(&b - &a, Vector::new(vec![2.0, 2.0]));
        assert_eq!(&a * 2.0, Vector::new(vec![2.0, 4.0]));
    }

    #[test]
    #[should_panic]
    fn test_add_dimension_mismatch_panics() {
        let _ = &Vector::new(vec![1.0]) + &Vector::new(vec![1.0, 2.0]);
    }
}
162