scirs2_cluster/vq/distance_metrics.rs

//! Comprehensive distance metrics for clustering algorithms
//!
//! This module provides a unified interface for distance computations used across
//! various clustering algorithms, including both standard metrics and advanced ones
//! like Mahalanobis distance. SIMD acceleration is provided where possible.
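//!
//! # Example
//!
//! A minimal usage sketch (the import path below assumes this module is exposed
//! as `scirs2_cluster::vq::distance_metrics`; adjust it to the actual re-export):
//!
//! ```ignore
//! use scirs2_cluster::vq::distance_metrics::{DistanceMetric, EuclideanDistance};
//! use scirs2_core::ndarray::Array1;
//!
//! let metric = EuclideanDistance;
//! let x = Array1::from_vec(vec![0.0_f64, 0.0]);
//! let y = Array1::from_vec(vec![3.0, 4.0]);
//! assert!((metric.distance(x.view(), y.view()) - 5.0).abs() < 1e-12);
//! ```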

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::Result;

/// Trait for distance metric computations
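///
/// Implementors only need to provide `distance` and `name`; the pairwise and
/// centroid helpers are supplied by default methods. A minimal sketch of a
/// custom metric (the `SquaredEuclidean` type below is hypothetical and assumes
/// this trait is in scope):
///
/// ```ignore
/// use scirs2_core::ndarray::ArrayView1;
/// use scirs2_core::numeric::{Float, FromPrimitive};
/// use std::fmt::Debug;
///
/// struct SquaredEuclidean;
///
/// impl<F> DistanceMetric<F> for SquaredEuclidean
/// where
///     F: Float + FromPrimitive + Debug + Send + Sync,
/// {
///     fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
///         x.iter()
///             .zip(y.iter())
///             .map(|(a, b)| (*a - *b) * (*a - *b))
///             .fold(F::zero(), |acc, d| acc + d)
///     }
///
///     fn name(&self) -> &'static str {
///         "sqeuclidean"
///     }
/// }
/// ```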
pub trait DistanceMetric<F>
where
    F: Float + FromPrimitive + Debug + Send + Sync,
{
    /// Compute distance between two vectors
    fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F;

    /// Compute pairwise distances between all points in `data`, returned in
    /// condensed form: the upper triangle (i < j) in row-major order, the same
    /// layout as SciPy's `pdist`
    fn pairwise_distances(&self, data: ArrayView2<F>) -> Array1<F> {
        let n_samples = data.shape()[0];
        // saturating_sub avoids underflow when `data` has fewer than two rows
        let n_distances = n_samples * n_samples.saturating_sub(1) / 2;
        let mut distances = Array1::zeros(n_distances);

        let mut idx = 0;
        for i in 0..n_samples {
            for j in (i + 1)..n_samples {
                let x = data.row(i);
                let y = data.row(j);
                distances[idx] = self.distance(x, y);
                idx += 1;
            }
        }
        distances
    }

    /// Compute distances from each point to a set of centroids
    ///
    /// Returns an `(n_samples, n_centroids)` matrix where entry `[i, j]` is the
    /// distance from sample `i` to centroid `j`.
    fn distances_to_centroids(&self, data: ArrayView2<F>, centroids: ArrayView2<F>) -> Array2<F> {
        let n_samples = data.shape()[0];
        let n_centroids = centroids.shape()[0];
        let mut distances = Array2::zeros((n_samples, n_centroids));

        for i in 0..n_samples {
            for j in 0..n_centroids {
                let x = data.row(i);
                let y = centroids.row(j);
                distances[[i, j]] = self.distance(x, y);
            }
        }
        distances
    }

    /// Get the name of this distance metric
    fn name(&self) -> &'static str;
}

/// Euclidean distance metric (L2 norm)
#[derive(Debug, Clone, Default)]
pub struct EuclideanDistance;

impl<F> DistanceMetric<F> for EuclideanDistance
where
    F: Float + FromPrimitive + Debug + Send + Sync,
{
    fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
        let mut sum = F::zero();
        for (a, b) in x.iter().zip(y.iter()) {
            let diff = *a - *b;
            sum = sum + diff * diff;
        }
        sum.sqrt()
    }

    fn name(&self) -> &'static str {
        "euclidean"
    }
}

/// Manhattan distance metric (L1 norm)
#[derive(Debug, Clone, Default)]
pub struct ManhattanDistance;

impl<F> DistanceMetric<F> for ManhattanDistance
where
    F: Float + FromPrimitive + Debug + Send + Sync,
{
    fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
        let mut sum = F::zero();
        for (a, b) in x.iter().zip(y.iter()) {
            sum = sum + (*a - *b).abs();
        }
        sum
    }

    fn name(&self) -> &'static str {
        "manhattan"
    }
}

/// Chebyshev distance metric (L∞ norm)
#[derive(Debug, Clone, Default)]
pub struct ChebyshevDistance;

impl<F> DistanceMetric<F> for ChebyshevDistance
where
    F: Float + FromPrimitive + Debug + Send + Sync,
{
    fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
        let mut max_diff = F::zero();
        for (a, b) in x.iter().zip(y.iter()) {
            let diff = (*a - *b).abs();
            if diff > max_diff {
                max_diff = diff;
            }
        }
        max_diff
    }

    fn name(&self) -> &'static str {
        "chebyshev"
    }
}

/// Minkowski distance metric with configurable power p
///
/// `p = 1` reduces to the Manhattan distance and `p = 2` to the Euclidean
/// distance; the result is a proper metric only for `p >= 1`.
#[derive(Debug, Clone)]
pub struct MinkowskiDistance<F> {
    /// The order of the Minkowski distance (p-norm parameter)
    pub p: F,
}

impl<F> MinkowskiDistance<F>
where
    F: Float + FromPrimitive + Debug,
{
    /// Create a new Minkowski distance metric with the given order p
    pub fn new(p: F) -> Self {
        Self { p }
    }
}

impl<F> DistanceMetric<F> for MinkowskiDistance<F>
where
    F: Float + FromPrimitive + Debug + Send + Sync,
{
    fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
        let mut sum = F::zero();
        for (a, b) in x.iter().zip(y.iter()) {
            sum = sum + (*a - *b).abs().powf(self.p);
        }
        sum.powf(F::one() / self.p)
    }

    fn name(&self) -> &'static str {
        "minkowski"
    }
}

/// Cosine distance metric (1 - cosine similarity)
#[derive(Debug, Clone, Default)]
pub struct CosineDistance;

impl<F> DistanceMetric<F> for CosineDistance
where
    F: Float + FromPrimitive + Debug + Send + Sync,
{
    fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
        let mut dot_product = F::zero();
        let mut norm_x = F::zero();
        let mut norm_y = F::zero();

        for (a, b) in x.iter().zip(y.iter()) {
            dot_product = dot_product + *a * *b;
            norm_x = norm_x + *a * *a;
            norm_y = norm_y + *b * *b;
        }

        norm_x = norm_x.sqrt();
        norm_y = norm_y.sqrt();

        if norm_x <= F::epsilon() || norm_y <= F::epsilon() {
            return F::one(); // Treat zero vectors as orthogonal (distance 1)
        }

        let cosine_similarity = dot_product / (norm_x * norm_y);
        // Clamp to [-1, 1] to handle numerical errors
        let cosine_similarity = cosine_similarity.max(-F::one()).min(F::one());
        F::one() - cosine_similarity
    }

    fn name(&self) -> &'static str {
        "cosine"
    }
}

/// Correlation distance metric (1 - Pearson correlation)
#[derive(Debug, Clone, Default)]
pub struct CorrelationDistance;

impl<F> DistanceMetric<F> for CorrelationDistance
where
    F: Float + FromPrimitive + Debug + Send + Sync,
{
    fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
        let n = F::from(x.len()).unwrap();

        // Calculate means
        let mean_x = x.sum() / n;
        let mean_y = y.sum() / n;

        // Calculate correlation coefficient
        let mut numerator = F::zero();
        let mut sum_sq_x = F::zero();
        let mut sum_sq_y = F::zero();

        for (a, b) in x.iter().zip(y.iter()) {
            let diff_x = *a - mean_x;
            let diff_y = *b - mean_y;

            numerator = numerator + diff_x * diff_y;
            sum_sq_x = sum_sq_x + diff_x * diff_x;
            sum_sq_y = sum_sq_y + diff_y * diff_y;
        }

        let denominator = (sum_sq_x * sum_sq_y).sqrt();

        if denominator <= F::epsilon() {
            return F::one(); // Treat constant vectors as uncorrelated (distance 1)
        }

        let correlation = numerator / denominator;
        // Clamp to [-1, 1] to handle numerical errors
        let correlation = correlation.max(-F::one()).min(F::one());
        F::one() - correlation
    }

    fn name(&self) -> &'static str {
        "correlation"
    }
}

/// Mahalanobis distance metric using a precomputed inverse covariance matrix
///
/// Computes `sqrt((x - y)ᵀ Σ⁻¹ (x - y))`, where `Σ` is the covariance matrix
/// of the training data.
#[derive(Debug, Clone)]
pub struct MahalanobisDistance<F> {
    /// Inverse covariance matrix
    pub inv_cov: Array2<F>,
}

impl<F> MahalanobisDistance<F>
where
    F: Float + FromPrimitive + Debug + Send + Sync + ScalarOperand,
{
    /// Create a new Mahalanobis distance metric
    ///
    /// # Arguments
    ///
    /// * `data` - Training data to compute the covariance matrix from
    ///
    /// # Returns
    ///
    /// * Result containing the Mahalanobis distance metric or an error
    pub fn fromdata(data: ArrayView2<F>) -> Result<Self> {
        let cov_matrix = compute_covariance_matrix(data)?;
        let inv_cov = invert_matrix(cov_matrix)?;
        Ok(Self { inv_cov })
    }

    /// Create a Mahalanobis distance metric from a precomputed inverse covariance matrix
    pub fn from_inv_cov(inv_cov: Array2<F>) -> Self {
        Self { inv_cov }
    }
}

impl<F> DistanceMetric<F> for MahalanobisDistance<F>
where
    F: Float + FromPrimitive + Debug + Send + Sync + 'static,
{
    fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
        let diff = &x.to_owned() - &y.to_owned();
        let temp = self.inv_cov.dot(&diff);
        let result = diff.dot(&temp);
        result.sqrt()
    }

    fn name(&self) -> &'static str {
        "mahalanobis"
    }
}

/// Compute the covariance matrix of the given data
#[allow(dead_code)]
fn compute_covariance_matrix<F>(data: ArrayView2<F>) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + ScalarOperand,
{
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples <= 1 {
        return Err(crate::error::ClusteringError::InvalidInput(
            "Need at least 2 samples to compute covariance matrix".into(),
        ));
    }

    // Compute means
    let means = data.mean_axis(Axis(0)).unwrap();

    // Center the data
    let mut centered_data = Array2::zeros((n_samples, n_features));
    for i in 0..n_samples {
        for j in 0..n_features {
            centered_data[[i, j]] = data[[i, j]] - means[j];
        }
    }

    // Compute covariance matrix: (1/(n-1)) * X^T * X
    let cov = centered_data.t().dot(&centered_data) / F::from(n_samples - 1).unwrap();
    Ok(cov)
}

/// Simple matrix inversion using Gauss-Jordan elimination with partial pivoting
#[allow(dead_code)]
fn invert_matrix<F>(matrix: Array2<F>) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + ScalarOperand,
{
    let n = matrix.shape()[0];
    if n != matrix.shape()[1] {
        return Err(crate::error::ClusteringError::InvalidInput(
            "Matrix must be square for inversion".into(),
        ));
    }

    // Simple Gauss-Jordan elimination for small matrices
    // For production use, consider using ndarray-linalg for better numerical stability
    let mut aug = Array2::zeros((n, 2 * n));

    // Set up augmented matrix [A | I]
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = matrix[[i, j]];
        }
        aug[[i, n + i]] = F::one();
    }

    // Gauss-Jordan elimination with partial pivoting
    for i in 0..n {
        // Find pivot
        let mut max_row = i;
        for k in (i + 1)..n {
            if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                max_row = k;
            }
        }

        // Swap rows
        if max_row != i {
            for j in 0..(2 * n) {
                let temp = aug[[i, j]];
                aug[[i, j]] = aug[[max_row, j]];
                aug[[max_row, j]] = temp;
            }
        }

        // Check for singularity
        if aug[[i, i]].abs() <= F::epsilon() {
            return Err(crate::error::ClusteringError::ComputationError(
                "Matrix is singular and cannot be inverted".into(),
            ));
        }

        // Make diagonal element 1
        let pivot = aug[[i, i]];
        for j in 0..(2 * n) {
            aug[[i, j]] = aug[[i, j]] / pivot;
        }

        // Eliminate column
        for k in 0..n {
            if k != i {
                let factor = aug[[k, i]];
                for j in 0..(2 * n) {
                    aug[[k, j]] = aug[[k, j]] - factor * aug[[i, j]];
                }
            }
        }
    }

    // Extract the inverse matrix
    let mut inv = Array2::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            inv[[i, j]] = aug[[i, n + j]];
        }
    }

    Ok(inv)
}

/// Enumeration of available distance metrics
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetricType {
    /// Euclidean (L2) distance metric
    Euclidean,
    /// Manhattan (L1) distance metric
    Manhattan,
    /// Chebyshev (L∞) distance metric
    Chebyshev,
    /// Minkowski distance metric with configurable order
    Minkowski,
    /// Cosine distance metric based on angle between vectors
    Cosine,
    /// Correlation distance metric
    Correlation,
    /// Mahalanobis distance metric accounting for covariance
    Mahalanobis,
}

/// Create a distance metric instance from the metric type
///
/// `data` is required only for `MetricType::Mahalanobis`; `p` applies only to
/// `MetricType::Minkowski` and defaults to 2 when omitted.
#[allow(dead_code)]
pub fn create_metric<F>(
    metric_type: MetricType,
    data: Option<ArrayView2<F>>,
    p: Option<F>,
) -> Result<Box<dyn DistanceMetric<F>>>
where
    F: Float + FromPrimitive + Debug + Send + Sync + ScalarOperand + 'static,
{
    match metric_type {
        MetricType::Euclidean => Ok(Box::new(EuclideanDistance)),
        MetricType::Manhattan => Ok(Box::new(ManhattanDistance)),
        MetricType::Chebyshev => Ok(Box::new(ChebyshevDistance)),
        MetricType::Minkowski => {
            // Default to p = 2 (equivalent to Euclidean) when no order is given
            let p = p.unwrap_or_else(|| F::from(2.0).unwrap());
            Ok(Box::new(MinkowskiDistance::new(p)))
        }
        MetricType::Cosine => Ok(Box::new(CosineDistance)),
        MetricType::Correlation => Ok(Box::new(CorrelationDistance)),
        MetricType::Mahalanobis => {
            let data = data.ok_or_else(|| {
                crate::error::ClusteringError::InvalidInput(
                    "Data required for Mahalanobis distance computation".into(),
                )
            })?;
            let metric = MahalanobisDistance::fromdata(data)?;
            Ok(Box::new(metric))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_euclidean_distance() {
        let metric = EuclideanDistance;
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let distance = metric.distance(x.view(), y.view());
        let expected = ((3.0_f64).powi(2) * 3.0).sqrt(); // sqrt(9 + 9 + 9) = sqrt(27)
        assert_abs_diff_eq!(distance, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_manhattan_distance() {
        let metric = ManhattanDistance;
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let distance = metric.distance(x.view(), y.view());
        let expected = 9.0; // |1-4| + |2-5| + |3-6| = 3 + 3 + 3 = 9
        assert_abs_diff_eq!(distance, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_chebyshev_distance() {
        let metric = ChebyshevDistance;
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 6.0, 5.0]);

        let distance = metric.distance(x.view(), y.view());
        let expected = 4.0; // max(|1-4|, |2-6|, |3-5|) = max(3, 4, 2) = 4
        assert_abs_diff_eq!(distance, expected, epsilon = 1e-10);
    }
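
    // Sanity check: the Minkowski metric should reduce to Manhattan at p = 1
    // and to Euclidean at p = 2.
    #[test]
    fn test_minkowski_distance_reduces_to_l1_and_l2() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let l1 = MinkowskiDistance::new(1.0).distance(x.view(), y.view());
        let manhattan = ManhattanDistance.distance(x.view(), y.view());
        assert_abs_diff_eq!(l1, manhattan, epsilon = 1e-10);

        let l2 = MinkowskiDistance::new(2.0).distance(x.view(), y.view());
        let euclidean = EuclideanDistance.distance(x.view(), y.view());
        assert_abs_diff_eq!(l2, euclidean, epsilon = 1e-10);
    }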

    #[test]
    fn test_cosine_distance() {
        let metric = CosineDistance;
        let x = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let y = Array1::from_vec(vec![0.0, 1.0, 0.0]);

        let distance = metric.distance(x.view(), y.view());
        let expected = 1.0; // cosine similarity is 0, so distance is 1 - 0 = 1
        assert_abs_diff_eq!(distance, expected, epsilon = 1e-10);

        // Test parallel vectors
        let z = Array1::from_vec(vec![2.0, 0.0, 0.0]);
        let distance_parallel = metric.distance(x.view(), z.view());
        let expected_parallel = 0.0; // cosine similarity is 1, so distance is 1 - 1 = 0
        assert_abs_diff_eq!(distance_parallel, expected_parallel, epsilon = 1e-10);
    }
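
    // Sanity check for the correlation metric: identical vectors should give
    // distance 0 and perfectly anti-correlated vectors distance 2.
    #[test]
    fn test_correlation_distance() {
        let metric = CorrelationDistance;
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let y = Array1::from_vec(vec![4.0, 3.0, 2.0, 1.0]);

        let same = metric.distance(x.view(), x.view());
        assert_abs_diff_eq!(same, 0.0, epsilon = 1e-10);

        let opposite = metric.distance(x.view(), y.view());
        assert_abs_diff_eq!(opposite, 2.0, epsilon = 1e-10);
    }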

    #[test]
    fn test_mahalanobis_distance() {
        // Create test data with more variance to avoid singular matrix
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 1.0, 3.0, 4.0, 4.0, 3.0, 5.0, 6.0, 6.0, 5.0],
        )
        .unwrap();

        let metric = MahalanobisDistance::fromdata(data.view()).unwrap();

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let y = Array1::from_vec(vec![2.0, 3.0]);

        let distance = metric.distance(x.view(), y.view());

        // The exact value depends on the covariance matrix, but it should be finite and positive
        assert!(distance.is_finite());
        assert!(distance >= 0.0);
    }
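
    // The inversion helper is private but reachable through `super::*`; a quick
    // round-trip sketch checking that A * A⁻¹ recovers the identity for a small,
    // well-conditioned matrix.
    #[test]
    fn test_invert_matrix_round_trip() {
        let a = Array2::from_shape_vec((2, 2), vec![4.0, 2.0, 2.0, 3.0]).unwrap();
        let inv = invert_matrix(a.clone()).unwrap();
        let product = a.dot(&inv);

        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(product[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }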

    #[test]
    fn test_pairwise_distances() {
        let metric = EuclideanDistance;
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).unwrap();

        let distances = metric.pairwise_distances(data.view());

        // Should have 3 choose 2 = 3 distances
        assert_eq!(distances.len(), 3);

        // Check specific distances
        assert_abs_diff_eq!(distances[0], 1.0, epsilon = 1e-10); // (0,0) to (1,0)
        assert_abs_diff_eq!(distances[1], 1.0, epsilon = 1e-10); // (0,0) to (0,1)
        assert_abs_diff_eq!(distances[2], 2.0_f64.sqrt(), epsilon = 1e-10); // (1,0) to (0,1)
    }
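
    // Hand-checked covariance sketch: for the three points below the sample
    // covariance (with the 1/(n-1) normalization used above) is [[4, 4], [4, 4]].
    #[test]
    fn test_compute_covariance_matrix_hand_checked() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let cov = compute_covariance_matrix(data.view()).unwrap();

        for i in 0..2 {
            for j in 0..2 {
                assert_abs_diff_eq!(cov[[i, j]], 4.0, epsilon = 1e-10);
            }
        }
    }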

    #[test]
    fn test_distances_to_centroids() {
        let metric = EuclideanDistance;
        let data = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();

        let centroids = Array2::from_shape_vec((1, 2), vec![0.5, 0.5]).unwrap();

        let distances = metric.distances_to_centroids(data.view(), centroids.view());

        assert_eq!(distances.shape(), &[2, 1]);
        assert_abs_diff_eq!(
            distances[[0, 0]],
            (0.5_f64.powi(2) * 2.0).sqrt(),
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(
            distances[[1, 0]],
            (0.5_f64.powi(2) * 2.0).sqrt(),
            epsilon = 1e-10
        );
    }
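
    // A brief sketch of the factory: `create_metric` should hand back a boxed
    // metric whose `name` matches the requested `MetricType`, and Mahalanobis
    // should be rejected when no training data is supplied.
    #[test]
    fn test_create_metric_factory() {
        let euclidean = create_metric::<f64>(MetricType::Euclidean, None, None).unwrap();
        assert_eq!(euclidean.name(), "euclidean");

        let minkowski = create_metric::<f64>(MetricType::Minkowski, None, Some(3.0)).unwrap();
        assert_eq!(minkowski.name(), "minkowski");

        // Mahalanobis requires data to estimate the covariance matrix
        assert!(create_metric::<f64>(MetricType::Mahalanobis, None, None).is_err());
    }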
}