1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, ScalarOperand};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::Debug;
10
11use crate::error::Result;
12
13pub trait DistanceMetric<F>
15where
16 F: Float + FromPrimitive + Debug + Send + Sync,
17{
18 fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F;
20
21 fn pairwise_distances(&self, data: ArrayView2<F>) -> Array1<F> {
23 let n_samples = data.shape()[0];
24 let n_distances = n_samples * (n_samples - 1) / 2;
25 let mut distances = Array1::zeros(n_distances);
26
27 let mut idx = 0;
28 for i in 0..n_samples {
29 for j in (i + 1)..n_samples {
30 let x = data.row(i);
31 let y = data.row(j);
32 distances[idx] = self.distance(x, y);
33 idx += 1;
34 }
35 }
36 distances
37 }
38
39 fn distances_to_centroids(&self, data: ArrayView2<F>, centroids: ArrayView2<F>) -> Array2<F> {
41 let n_samples = data.shape()[0];
42 let n_centroids = centroids.shape()[0];
43 let mut distances = Array2::zeros((n_samples, n_centroids));
44
45 for i in 0..n_samples {
46 for j in 0..n_centroids {
47 let x = data.row(i);
48 let y = centroids.row(j);
49 distances[[i, j]] = self.distance(x, y);
50 }
51 }
52 distances
53 }
54
55 fn name(&self) -> &'static str;
57}
58
59#[derive(Debug, Clone, Default)]
61pub struct EuclideanDistance;
62
63impl<F> DistanceMetric<F> for EuclideanDistance
64where
65 F: Float + FromPrimitive + Debug + Send + Sync,
66{
67 fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
68 let mut sum = F::zero();
69 for (a, b) in x.iter().zip(y.iter()) {
70 let diff = *a - *b;
71 sum = sum + diff * diff;
72 }
73 sum.sqrt()
74 }
75
76 fn name(&self) -> &'static str {
77 "euclidean"
78 }
79}
80
81#[derive(Debug, Clone, Default)]
83pub struct ManhattanDistance;
84
85impl<F> DistanceMetric<F> for ManhattanDistance
86where
87 F: Float + FromPrimitive + Debug + Send + Sync,
88{
89 fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
90 let mut sum = F::zero();
91 for (a, b) in x.iter().zip(y.iter()) {
92 sum = sum + (*a - *b).abs();
93 }
94 sum
95 }
96
97 fn name(&self) -> &'static str {
98 "manhattan"
99 }
100}
101
102#[derive(Debug, Clone, Default)]
104pub struct ChebyshevDistance;
105
106impl<F> DistanceMetric<F> for ChebyshevDistance
107where
108 F: Float + FromPrimitive + Debug + Send + Sync,
109{
110 fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
111 let mut max_diff = F::zero();
112 for (a, b) in x.iter().zip(y.iter()) {
113 let diff = (*a - *b).abs();
114 if diff > max_diff {
115 max_diff = diff;
116 }
117 }
118 max_diff
119 }
120
121 fn name(&self) -> &'static str {
122 "chebyshev"
123 }
124}
125
126#[derive(Debug, Clone)]
128pub struct MinkowskiDistance<F> {
129 pub p: F,
131}
132
133impl<F> MinkowskiDistance<F>
134where
135 F: Float + FromPrimitive + Debug,
136{
137 pub fn new(p: F) -> Self {
139 Self { p }
140 }
141}
142
143impl<F> DistanceMetric<F> for MinkowskiDistance<F>
144where
145 F: Float + FromPrimitive + Debug + Send + Sync,
146{
147 fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
148 let mut sum = F::zero();
149 for (a, b) in x.iter().zip(y.iter()) {
150 sum = sum + (*a - *b).abs().powf(self.p);
151 }
152 sum.powf(F::one() / self.p)
153 }
154
155 fn name(&self) -> &'static str {
156 "minkowski"
157 }
158}
159
160#[derive(Debug, Clone, Default)]
162pub struct CosineDistance;
163
164impl<F> DistanceMetric<F> for CosineDistance
165where
166 F: Float + FromPrimitive + Debug + Send + Sync,
167{
168 fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
169 let mut dot_product = F::zero();
170 let mut norm_x = F::zero();
171 let mut norm_y = F::zero();
172
173 for (a, b) in x.iter().zip(y.iter()) {
174 dot_product = dot_product + *a * *b;
175 norm_x = norm_x + *a * *a;
176 norm_y = norm_y + *b * *b;
177 }
178
179 norm_x = norm_x.sqrt();
180 norm_y = norm_y.sqrt();
181
182 if norm_x <= F::epsilon() || norm_y <= F::epsilon() {
183 return F::one(); }
185
186 let cosine_similarity = dot_product / (norm_x * norm_y);
187 let cosine_similarity = cosine_similarity.max(-F::one()).min(F::one());
189 F::one() - cosine_similarity
190 }
191
192 fn name(&self) -> &'static str {
193 "cosine"
194 }
195}
196
197#[derive(Debug, Clone, Default)]
199pub struct CorrelationDistance;
200
201impl<F> DistanceMetric<F> for CorrelationDistance
202where
203 F: Float + FromPrimitive + Debug + Send + Sync,
204{
205 fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
206 let n = F::from(x.len()).unwrap();
207
208 let mean_x = x.sum() / n;
210 let mean_y = y.sum() / n;
211
212 let mut numerator = F::zero();
214 let mut sum_sq_x = F::zero();
215 let mut sum_sq_y = F::zero();
216
217 for (a, b) in x.iter().zip(y.iter()) {
218 let diff_x = *a - mean_x;
219 let diff_y = *b - mean_y;
220
221 numerator = numerator + diff_x * diff_y;
222 sum_sq_x = sum_sq_x + diff_x * diff_x;
223 sum_sq_y = sum_sq_y + diff_y * diff_y;
224 }
225
226 let denominator = (sum_sq_x * sum_sq_y).sqrt();
227
228 if denominator <= F::epsilon() {
229 return F::one(); }
231
232 let correlation = numerator / denominator;
233 let correlation = correlation.max(-F::one()).min(F::one());
235 F::one() - correlation
236 }
237
238 fn name(&self) -> &'static str {
239 "correlation"
240 }
241}
242
243#[derive(Debug, Clone)]
245pub struct MahalanobisDistance<F> {
246 pub inv_cov: Array2<F>,
248}
249
250impl<F> MahalanobisDistance<F>
251where
252 F: Float + FromPrimitive + Debug + Send + Sync + ScalarOperand,
253{
254 pub fn fromdata(data: ArrayView2<F>) -> Result<Self> {
264 let cov_matrix = compute_covariance_matrix(data)?;
265 let inv_cov = invert_matrix(cov_matrix)?;
266 Ok(Self { inv_cov })
267 }
268
269 pub fn from_inv_cov(_invcov: Array2<F>) -> Self {
271 Self { inv_cov: _invcov }
272 }
273}
274
275impl<F> DistanceMetric<F> for MahalanobisDistance<F>
276where
277 F: Float + FromPrimitive + Debug + Send + Sync + 'static,
278{
279 fn distance(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> F {
280 let diff = &x.to_owned() - &y.to_owned();
281 let temp = self.inv_cov.dot(&diff);
282 let result = diff.dot(&temp);
283 result.sqrt()
284 }
285
286 fn name(&self) -> &'static str {
287 "mahalanobis"
288 }
289}
290
291#[allow(dead_code)]
293fn compute_covariance_matrix<F>(data: ArrayView2<F>) -> Result<Array2<F>>
294where
295 F: Float + FromPrimitive + Debug + ScalarOperand,
296{
297 let n_samples = data.shape()[0];
298 let n_features = data.shape()[1];
299
300 if n_samples <= 1 {
301 return Err(crate::error::ClusteringError::InvalidInput(
302 "Need at least 2 samples to compute covariance matrix".into(),
303 ));
304 }
305
306 let means = data.mean_axis(Axis(0)).unwrap();
308
309 let mut centereddata = Array2::zeros((n_samples, n_features));
311 for i in 0..n_samples {
312 for j in 0..n_features {
313 centereddata[[i, j]] = data[[i, j]] - means[j];
314 }
315 }
316
317 let cov = centereddata.t().dot(¢ereddata) / F::from(n_samples - 1).unwrap();
319 Ok(cov)
320}
321
322#[allow(dead_code)]
324fn invert_matrix<F>(matrix: Array2<F>) -> Result<Array2<F>>
325where
326 F: Float + FromPrimitive + Debug + ScalarOperand,
327{
328 let n = matrix.shape()[0];
329 if n != matrix.shape()[1] {
330 return Err(crate::error::ClusteringError::InvalidInput(
331 "Matrix must be square for inversion".into(),
332 ));
333 }
334
335 let mut aug = Array2::zeros((n, 2 * n));
338
339 for i in 0..n {
341 for j in 0..n {
342 aug[[i, j]] = matrix[[i, j]];
343 }
344 aug[[i, n + i]] = F::one();
345 }
346
347 for i in 0..n {
349 let mut max_row = i;
351 for k in (i + 1)..n {
352 if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
353 max_row = k;
354 }
355 }
356
357 if max_row != i {
359 for j in 0..(2 * n) {
360 let temp = aug[[i, j]];
361 aug[[i, j]] = aug[[max_row, j]];
362 aug[[max_row, j]] = temp;
363 }
364 }
365
366 if aug[[i, i]].abs() <= F::epsilon() {
368 return Err(crate::error::ClusteringError::ComputationError(
369 "Matrix is singular and cannot be inverted".into(),
370 ));
371 }
372
373 let pivot = aug[[i, i]];
375 for j in 0..(2 * n) {
376 aug[[i, j]] = aug[[i, j]] / pivot;
377 }
378
379 for k in 0..n {
381 if k != i {
382 let factor = aug[[k, i]];
383 for j in 0..(2 * n) {
384 aug[[k, j]] = aug[[k, j]] - factor * aug[[i, j]];
385 }
386 }
387 }
388 }
389
390 let mut inv = Array2::zeros((n, n));
392 for i in 0..n {
393 for j in 0..n {
394 inv[[i, j]] = aug[[i, n + j]];
395 }
396 }
397
398 Ok(inv)
399}
400
/// Selects which distance metric [`create_metric`] builds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetricType {
    /// Straight-line (L2) distance; see [`EuclideanDistance`].
    Euclidean,
    /// Sum of absolute coordinate differences (L1); see [`ManhattanDistance`].
    Manhattan,
    /// Largest absolute coordinate difference (L-infinity); see [`ChebyshevDistance`].
    Chebyshev,
    /// Lp distance of configurable order `p`; see [`MinkowskiDistance`].
    Minkowski,
    /// One minus cosine similarity; see [`CosineDistance`].
    Cosine,
    /// One minus Pearson correlation; see [`CorrelationDistance`].
    Correlation,
    /// Distance weighted by an inverse covariance matrix; see [`MahalanobisDistance`].
    Mahalanobis,
}
419
420#[allow(dead_code)]
422pub fn create_metric<F>(
423 metric_type: MetricType,
424 data: Option<ArrayView2<F>>,
425 p: Option<F>,
426) -> Result<Box<dyn DistanceMetric<F>>>
427where
428 F: Float + FromPrimitive + Debug + Send + Sync + ScalarOperand + 'static,
429{
430 match metric_type {
431 MetricType::Euclidean => Ok(Box::new(EuclideanDistance)),
432 MetricType::Manhattan => Ok(Box::new(ManhattanDistance)),
433 MetricType::Chebyshev => Ok(Box::new(ChebyshevDistance)),
434 MetricType::Minkowski => {
435 let p = p.unwrap_or_else(|| F::from(2.0).unwrap());
436 Ok(Box::new(MinkowskiDistance::new(p)))
437 }
438 MetricType::Cosine => Ok(Box::new(CosineDistance)),
439 MetricType::Correlation => Ok(Box::new(CorrelationDistance)),
440 MetricType::Mahalanobis => {
441 let data = data.ok_or_else(|| {
442 crate::error::ClusteringError::InvalidInput(
443 "Data required for Mahalanobis distance computation".into(),
444 )
445 })?;
446 let metric = MahalanobisDistance::fromdata(data)?;
447 Ok(Box::new(metric))
448 }
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// Builds an owned `f64` vector from a slice literal.
    fn vector(values: &[f64]) -> Array1<f64> {
        Array1::from_vec(values.to_vec())
    }

    #[test]
    fn test_euclidean_distance() {
        let (a, b) = (vector(&[1.0, 2.0, 3.0]), vector(&[4.0, 5.0, 6.0]));
        let got = EuclideanDistance.distance(a.view(), b.view());
        // Every coordinate differs by 3, so the distance is sqrt(3 * 3^2).
        let want = (3.0_f64.powi(2) * 3.0).sqrt();
        assert_abs_diff_eq!(got, want, epsilon = TOL);
    }

    #[test]
    fn test_manhattan_distance() {
        let (a, b) = (vector(&[1.0, 2.0, 3.0]), vector(&[4.0, 5.0, 6.0]));
        let got = ManhattanDistance.distance(a.view(), b.view());
        assert_abs_diff_eq!(got, 9.0, epsilon = TOL);
    }

    #[test]
    fn test_chebyshev_distance() {
        let (a, b) = (vector(&[1.0, 2.0, 3.0]), vector(&[4.0, 6.0, 5.0]));
        // Gaps are 3, 4 and 2; the largest one wins.
        let got = ChebyshevDistance.distance(a.view(), b.view());
        assert_abs_diff_eq!(got, 4.0, epsilon = TOL);
    }

    #[test]
    fn test_cosine_distance() {
        let x_axis = vector(&[1.0, 0.0, 0.0]);
        let y_axis = vector(&[0.0, 1.0, 0.0]);
        let scaled_x = vector(&[2.0, 0.0, 0.0]);

        // Orthogonal vectors are at distance 1.
        let orthogonal = CosineDistance.distance(x_axis.view(), y_axis.view());
        assert_abs_diff_eq!(orthogonal, 1.0, epsilon = TOL);

        // Parallel vectors are at distance 0 regardless of length.
        let parallel = CosineDistance.distance(x_axis.view(), scaled_x.view());
        assert_abs_diff_eq!(parallel, 0.0, epsilon = TOL);
    }

    #[test]
    fn test_mahalanobis_distance() {
        let samples = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 1.0, 3.0, 4.0, 4.0, 3.0, 5.0, 6.0, 6.0, 5.0],
        )
        .unwrap();
        let metric = MahalanobisDistance::fromdata(samples.view()).unwrap();

        let (a, b) = (vector(&[1.0, 2.0]), vector(&[2.0, 3.0]));
        let got = metric.distance(a.view(), b.view());

        assert!(got.is_finite());
        assert!(got >= 0.0);
    }

    #[test]
    fn test_pairwise_distances() {
        let points = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).unwrap();

        let condensed = EuclideanDistance.pairwise_distances(points.view());

        assert_eq!(condensed.len(), 3);
        // Pair order is (0, 1), (0, 2), (1, 2).
        let expected = [1.0, 1.0, 2.0_f64.sqrt()];
        for (got, want) in condensed.iter().zip(expected) {
            assert_abs_diff_eq!(*got, want, epsilon = TOL);
        }
    }

    #[test]
    fn test_distances_to_centroids() {
        let points = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        let centroids = Array2::from_shape_vec((1, 2), vec![0.5, 0.5]).unwrap();

        let table = EuclideanDistance.distances_to_centroids(points.view(), centroids.view());

        assert_eq!(table.shape(), &[2, 1]);
        // Both points are offset by 0.5 in each coordinate from the single centroid.
        let want = (0.5_f64.powi(2) * 2.0).sqrt();
        for row in 0..2 {
            assert_abs_diff_eq!(table[[row, 0]], want, epsilon = TOL);
        }
    }
}