Skip to main content

vectrust_core/
vector_ops.rs

1use crate::*;
2
/// Stateless collection of vector similarity, distance, and normalization helpers.
///
/// All functionality is exposed as associated functions (for example
/// `VectorOps::cosine_similarity`); the type itself carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct VectorOps;
5
6impl VectorOps {
7    /// Calculate cosine similarity between two vectors
8    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
9        if a.len() != b.len() || a.is_empty() {
10            return 0.0;
11        }
12        
13        let mut dot_product = 0.0;
14        let mut norm_a = 0.0;
15        let mut norm_b = 0.0;
16        
17        // Vectorized computation - compiler will auto-vectorize this loop
18        for i in 0..a.len() {
19            dot_product += a[i] * b[i];
20            norm_a += a[i] * a[i];
21            norm_b += b[i] * b[i];
22        }
23        
24        if norm_a == 0.0 || norm_b == 0.0 {
25            return 0.0;
26        }
27        
28        dot_product / (norm_a.sqrt() * norm_b.sqrt())
29    }
30    
31    /// Calculate Euclidean distance between two vectors
32    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
33        if a.len() != b.len() {
34            return f32::INFINITY;
35        }
36        
37        let mut sum_sq = 0.0;
38        for i in 0..a.len() {
39            let diff = a[i] - b[i];
40            sum_sq += diff * diff;
41        }
42        
43        sum_sq.sqrt()
44    }
45    
46    /// Calculate dot product between two vectors
47    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
48        if a.len() != b.len() {
49            return 0.0;
50        }
51        
52        let mut product = 0.0;
53        for i in 0..a.len() {
54            product += a[i] * b[i];
55        }
56        
57        product
58    }
59    
60    /// Calculate similarity based on the specified distance metric
61    pub fn calculate_similarity(a: &[f32], b: &[f32], metric: &DistanceMetric) -> f32 {
62        match metric {
63            DistanceMetric::Cosine => Self::cosine_similarity(a, b),
64            DistanceMetric::Euclidean => {
65                // Convert distance to similarity (higher is better)
66                let distance = Self::euclidean_distance(a, b);
67                if distance == 0.0 {
68                    1.0
69                } else {
70                    1.0 / (1.0 + distance)
71                }
72            },
73            DistanceMetric::DotProduct => Self::dot_product(a, b),
74        }
75    }
76    
77    /// Normalize a vector to unit length
78    pub fn normalize(vector: &mut [f32]) {
79        let norm = vector.iter().map(|&x| x * x).sum::<f32>().sqrt();
80        if norm > 0.0 {
81            for x in vector.iter_mut() {
82                *x /= norm;
83            }
84        }
85    }
86    
87    /// Create a normalized copy of a vector
88    pub fn normalized(vector: &[f32]) -> Vec<f32> {
89        let mut result = vector.to_vec();
90        Self::normalize(&mut result);
91        result
92    }
93    
94    /// Check if two vectors have the same dimensions
95    pub fn compatible_dimensions(a: &[f32], b: &[f32]) -> bool {
96        a.len() == b.len() && !a.is_empty()
97    }
98    
99    /// Validate vector for NaN or infinite values
100    pub fn is_valid_vector(vector: &[f32]) -> bool {
101        vector.iter().all(|&x| x.is_finite())
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing computed `f32` results.
    const EPS: f32 = 1e-6;

    /// Return `true` if `x` and `y` differ by less than [`EPS`].
    fn approx_eq(x: f32, y: f32) -> bool {
        (x - y).abs() < EPS
    }

    #[test]
    fn test_cosine_similarity() {
        let a: [f32; 3] = [1.0, 0.0, 0.0];
        let b: [f32; 3] = [1.0, 0.0, 0.0];
        assert!(approx_eq(VectorOps::cosine_similarity(&a, &b), 1.0));

        let c: [f32; 3] = [0.0, 1.0, 0.0];
        assert!(approx_eq(VectorOps::cosine_similarity(&a, &c), 0.0));
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        let a: [f32; 3] = [1.0, 0.0, 0.0];
        let zero: [f32; 3] = [0.0; 3];
        let short: [f32; 2] = [1.0, 0.0];
        let empty: [f32; 0] = [];
        // Zero-magnitude, mismatched-length, and empty inputs all map to 0.0.
        assert_eq!(VectorOps::cosine_similarity(&a, &zero), 0.0);
        assert_eq!(VectorOps::cosine_similarity(&a, &short), 0.0);
        assert_eq!(VectorOps::cosine_similarity(&empty, &empty), 0.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let a: [f32; 3] = [0.0, 0.0, 0.0];
        let b: [f32; 3] = [3.0, 4.0, 0.0];
        assert!(approx_eq(VectorOps::euclidean_distance(&a, &b), 5.0));

        // Mismatched dimensions are reported as infinitely far apart.
        let short: [f32; 2] = [1.0, 2.0];
        assert_eq!(VectorOps::euclidean_distance(&a, &short), f32::INFINITY);
    }

    #[test]
    fn test_dot_product() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        let b: [f32; 3] = [4.0, 5.0, 6.0];
        assert!(approx_eq(VectorOps::dot_product(&a, &b), 32.0));

        let short: [f32; 2] = [1.0, 2.0];
        assert_eq!(VectorOps::dot_product(&a, &short), 0.0);
    }

    #[test]
    fn test_normalization() {
        let mut vector: [f32; 3] = [3.0, 4.0, 0.0];
        VectorOps::normalize(&mut vector);
        let norm = vector.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!(approx_eq(norm, 1.0));
    }

    #[test]
    fn test_normalize_zero_vector_is_unchanged() {
        let mut vector: [f32; 3] = [0.0; 3];
        VectorOps::normalize(&mut vector);
        assert_eq!(vector, [0.0; 3]);
    }

    #[test]
    fn test_normalized_leaves_input_untouched() {
        let source: [f32; 2] = [0.0, 5.0];
        let unit = VectorOps::normalized(&source);
        assert!(approx_eq(unit[0], 0.0) && approx_eq(unit[1], 1.0));
        assert_eq!(source, [0.0, 5.0]);
    }

    #[test]
    fn test_validation_helpers() {
        let one: [f32; 1] = [1.0];
        let two: [f32; 2] = [1.0, 2.0];
        let empty: [f32; 0] = [];
        assert!(VectorOps::compatible_dimensions(&one, &one));
        assert!(!VectorOps::compatible_dimensions(&one, &two));
        assert!(!VectorOps::compatible_dimensions(&empty, &empty));

        assert!(VectorOps::is_valid_vector(&two));
        assert!(!VectorOps::is_valid_vector(&[1.0, f32::NAN]));
        assert!(!VectorOps::is_valid_vector(&[f32::NEG_INFINITY]));
    }
}
133}