use crate::error::{SpatialError, SpatialResult};
use crate::generic_traits::{DistanceMetric, Point, SpatialPoint, SpatialScalar};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::PlatformCapabilities;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::marker::PhantomData;

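/// A generic KD-tree over any point type implementing `SpatialPoint<T>`,
/// supporting k-nearest-neighbor and radius queries with a caller-supplied
/// `DistanceMetric`.
///
/// A minimal usage sketch, assuming the generic `Point` type from
/// `generic_traits` and some metric implementing
/// `DistanceMetric<f64, Point<f64>>`; the `EuclideanMetric` name below is
/// illustrative, not a type this module defines:
///
/// ```ignore
/// let points = vec![
///     Point::new(vec![0.0_f64, 0.0]),
///     Point::new(vec![1.0, 1.0]),
///     Point::new(vec![2.0, 2.0]),
/// ];
/// let tree = GenericKDTree::new(&points)?;
/// // Indices and distances of the two nearest points, nearest first.
/// let neighbors = tree.k_nearest_neighbors(&points[0], 2, &EuclideanMetric)?;
/// assert_eq!(neighbors[0].0, 0); // the query point itself
/// ```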
#[derive(Debug, Clone)]
pub struct GenericKDTree<T: SpatialScalar, P: SpatialPoint<T>> {
    root: Option<Box<KDNode<T, P>>>,
    points: Vec<P>,
    dimension: usize,
    #[allow(dead_code)]
    leaf_size: usize,
}

#[derive(Debug, Clone)]
struct KDNode<T: SpatialScalar, P: SpatialPoint<T>> {
    point_index: usize,
    splitting_dimension: usize,
    left: Option<Box<KDNode<T, P>>>,
    right: Option<Box<KDNode<T, P>>>,
    _phantom: PhantomData<(T, P)>,
}

impl<T: SpatialScalar, P: SpatialPoint<T> + Clone> GenericKDTree<T, P> {
    pub fn new(points: &[P]) -> SpatialResult<Self> {
        if points.is_empty() {
            return Ok(Self {
                root: None,
                points: Vec::new(),
                dimension: 0,
                leaf_size: 32,
            });
        }

        if points.len() > 1_000_000 {
            return Err(SpatialError::ValueError(format!(
                "Point collection too large: {} points. Maximum supported: 1,000,000",
                points.len()
            )));
        }

        let dimension = points[0].dimension();
        if dimension == 0 {
            return Err(SpatialError::ValueError(
                "Points must have at least one dimension".to_string(),
            ));
        }

        if dimension > 50 {
            return Err(SpatialError::ValueError(format!(
                "Dimension too high: {dimension}. KD-Tree is not efficient for dimensions > 50"
            )));
        }

        // Validate that every point has the expected dimension and only
        // finite coordinates before building the tree.
        for (i, point) in points.iter().enumerate() {
            if point.dimension() != dimension {
                return Err(SpatialError::ValueError(format!(
                    "Point {} has dimension {} but expected {}",
                    i,
                    point.dimension(),
                    dimension
                )));
            }

            for d in 0..dimension {
                if let Some(coord) = point.coordinate(d) {
                    if !Float::is_finite(coord) {
                        return Err(SpatialError::ValueError(format!(
                            "Point {} has invalid coordinate {} at dimension {}",
                            i,
                            NumCast::from(coord).unwrap_or(f64::NAN),
                            d
                        )));
                    }
                }
            }
        }

        let points = points.to_vec();
        let mut indices: Vec<usize> = (0..points.len()).collect();

        let leaf_size = 32;
        let root = Self::build_tree(&points, &mut indices, 0, dimension, leaf_size);

        Ok(Self {
            root,
            points,
            dimension,
            leaf_size,
        })
    }

    fn build_tree(
        points: &[P],
        indices: &mut [usize],
        depth: usize,
        dimension: usize,
        leaf_size: usize,
    ) -> Option<Box<KDNode<T, P>>> {
        if indices.is_empty() {
            return None;
        }

        // Each node stores a single point, so recursion must continue until
        // exactly one index remains; stopping earlier at `leaf_size` would
        // silently drop the other points in the leaf from the tree.
        if indices.len() == 1 {
            return Some(Box::new(KDNode {
                point_index: indices[0],
                splitting_dimension: depth % dimension,
                left: None,
                right: None,
                _phantom: PhantomData,
            }));
        }

        let splitting_dimension = depth % dimension;

        // Sort the index slice along the splitting dimension and split at
        // the median to keep the tree balanced.
        indices.sort_by(|&a, &b| {
            let coord_a = points[a]
                .coordinate(splitting_dimension)
                .unwrap_or(T::zero());
            let coord_b = points[b]
                .coordinate(splitting_dimension)
                .unwrap_or(T::zero());
            coord_a.partial_cmp(&coord_b).unwrap_or(Ordering::Equal)
        });

        let median = indices.len() / 2;
        let point_index = indices[median];

        let (left_indices, right_indices) = indices.split_at_mut(median);
        let right_indices = &mut right_indices[1..];
        let left = Self::build_tree(points, left_indices, depth + 1, dimension, leaf_size);
        let right = Self::build_tree(points, right_indices, depth + 1, dimension, leaf_size);

        Some(Box::new(KDNode {
            point_index,
            splitting_dimension,
            left,
            right,
            _phantom: PhantomData,
        }))
    }

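    /// Returns the indices and distances of the `k` points nearest to
    /// `query`, sorted nearest first.
    ///
    /// A short sketch, assuming a tree built over 2D `Point<f64>` values and
    /// a caller-supplied `metric` implementing `DistanceMetric`:
    ///
    /// ```ignore
    /// let nearest = tree.k_nearest_neighbors(&Point::new(vec![0.5, 0.5]), 3, &metric)?;
    /// for (index, distance) in nearest {
    ///     println!("point #{index} at distance {distance}");
    /// }
    /// ```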
    pub fn k_nearest_neighbors(
        &self,
        query: &P,
        k: usize,
        metric: &dyn DistanceMetric<T, P>,
    ) -> SpatialResult<Vec<(usize, T)>> {
        if k == 0 {
            return Ok(Vec::new());
        }

        if k > self.points.len() {
            return Err(SpatialError::ValueError(format!(
                "k ({}) cannot be larger than the number of points ({})",
                k,
                self.points.len()
            )));
        }

        if k > 1000 {
            return Err(SpatialError::ValueError(format!(
                "k ({k}) is too large. Consider using radius search for k > 1000"
            )));
        }

        if query.dimension() != self.dimension {
            return Err(SpatialError::ValueError(format!(
                "Query point dimension ({}) must match tree dimension ({})",
                query.dimension(),
                self.dimension
            )));
        }

        for d in 0..query.dimension() {
            if let Some(coord) = query.coordinate(d) {
                if !Float::is_finite(coord) {
                    return Err(SpatialError::ValueError(format!(
                        "Query point has invalid coordinate {} at dimension {}",
                        NumCast::from(coord).unwrap_or(f64::NAN),
                        d
                    )));
                }
            }
        }

        if self.points.is_empty() {
            return Ok(Vec::new());
        }

        let mut heap = BinaryHeap::new();

        if let Some(ref root) = self.root {
            self.search_knn(root, query, k, &mut heap, metric);
        }

        // `into_sorted_vec` yields items in ascending `Ord` order, i.e.
        // nearest first, which is the order callers expect.
        let result: Vec<(usize, T)> = heap
            .into_sorted_vec()
            .into_iter()
            .map(|item| (item.index, item.distance))
            .collect();

        Ok(result)
    }

    fn search_knn(
        &self,
        node: &KDNode<T, P>,
        query: &P,
        k: usize,
        heap: &mut BinaryHeap<KNNItem<T>>,
        metric: &dyn DistanceMetric<T, P>,
    ) {
        let point = &self.points[node.point_index];
        let distance = metric.distance(query, point);

        // The heap holds the best k candidates; its top is the current
        // k-th nearest distance.
        if heap.len() < k {
            heap.push(KNNItem {
                distance,
                index: node.point_index,
            });
        } else if let Some(top) = heap.peek() {
            if distance < top.distance {
                heap.pop();
                heap.push(KNNItem {
                    distance,
                    index: node.point_index,
                });
            }
        }

        let query_coord = query
            .coordinate(node.splitting_dimension)
            .unwrap_or(T::zero());
        let point_coord = point
            .coordinate(node.splitting_dimension)
            .unwrap_or(T::zero());

        // Descend first into the subtree on the query's side of the
        // splitting plane.
        let (first_child, second_child) = if query_coord < point_coord {
            (&node.left, &node.right)
        } else {
            (&node.right, &node.left)
        };

        if let Some(ref child) = first_child {
            self.search_knn(child, query, k, heap, metric);
        }

        // Only visit the far subtree if the splitting plane is closer than
        // the current k-th nearest distance (or the heap is not yet full).
        let dimension_distance = (query_coord - point_coord).abs();
        let should_search_other = heap.len() < k
            || heap
                .peek()
                .is_none_or(|top| dimension_distance < top.distance);

        if should_search_other {
            if let Some(ref child) = second_child {
                self.search_knn(child, query, k, heap, metric);
            }
        }
    }

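    /// Returns all points within `radius` of `query` as `(index, distance)`
    /// pairs, in no particular order.
    ///
    /// A short sketch under the same assumptions as `k_nearest_neighbors`
    /// (a built `tree` and a caller-supplied `metric`):
    ///
    /// ```ignore
    /// let mut hits = tree.radius_search(&Point::new(vec![0.0, 0.0]), 1.5, &metric)?;
    /// // Sort by distance if an ordered result is needed.
    /// hits.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
    /// ```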
    pub fn radius_search(
        &self,
        query: &P,
        radius: T,
        metric: &dyn DistanceMetric<T, P>,
    ) -> SpatialResult<Vec<(usize, T)>> {
        if query.dimension() != self.dimension {
            return Err(SpatialError::ValueError(
                "Query point dimension must match tree dimension".to_string(),
            ));
        }

        let mut result = Vec::new();

        if let Some(ref root) = self.root {
            self.search_radius(root, query, radius, &mut result, metric);
        }

        Ok(result)
    }

    fn search_radius(
        &self,
        node: &KDNode<T, P>,
        query: &P,
        radius: T,
        result: &mut Vec<(usize, T)>,
        metric: &dyn DistanceMetric<T, P>,
    ) {
        let point = &self.points[node.point_index];
        let distance = metric.distance(query, point);

        if distance <= radius {
            result.push((node.point_index, distance));
        }

        let query_coord = query
            .coordinate(node.splitting_dimension)
            .unwrap_or(T::zero());
        let point_coord = point
            .coordinate(node.splitting_dimension)
            .unwrap_or(T::zero());

        // A subtree can only contain matches if the query ball reaches the
        // splitting plane on that side.
        if let Some(ref left) = node.left {
            if query_coord - radius <= point_coord {
                self.search_radius(left, query, radius, result, metric);
            }
        }

        if let Some(ref right) = node.right {
            if query_coord + radius >= point_coord {
                self.search_radius(right, query, radius, result, metric);
            }
        }
    }
}

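// Heap entry for k-NN search. `Ord` compares by distance, so a
// `BinaryHeap<KNNItem<T>>` is a max-heap whose top is the current k-th
// nearest candidate; it is evicted whenever a closer point is found.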
#[derive(Debug, Clone)]
struct KNNItem<T: SpatialScalar> {
    distance: T,
    index: usize,
}

impl<T: SpatialScalar> PartialEq for KNNItem<T> {
    fn eq(&self, other: &Self) -> bool {
        self.distance == other.distance
    }
}

impl<T: SpatialScalar> Eq for KNNItem<T> {}

impl<T: SpatialScalar> PartialOrd for KNNItem<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<T: SpatialScalar> Ord for KNNItem<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or(Ordering::Equal)
    }
}

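/// Pairwise distance-matrix computation over a slice of points.
///
/// A minimal usage sketch, assuming `Point<f64>` values and a caller-supplied
/// `metric` implementing `DistanceMetric` (both from `generic_traits`):
///
/// ```ignore
/// let matrix = GenericDistanceMatrix::compute(&points, &metric)?;
/// assert_eq!(matrix.len(), points.len());
/// assert_eq!(matrix[0][1], matrix[1][0]); // symmetric
/// ```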
pub struct GenericDistanceMatrix;

impl GenericDistanceMatrix {
    pub fn compute<T, P, M>(points: &[P], metric: &M) -> SpatialResult<Vec<Vec<T>>>
    where
        T: SpatialScalar + Send + Sync,
        P: SpatialPoint<T> + Send + Sync,
        M: DistanceMetric<T, P> + Send + Sync,
    {
        let n = points.len();

        // Small inputs are not worth the SIMD-oriented dispatch overhead.
        if n > 100 {
            Self::compute_simd_optimized(points, metric)
        } else {
            Self::compute_basic(points, metric)
        }
    }

    pub fn compute_flat<T, P, M>(points: &[P], metric: &M) -> SpatialResult<Vec<T>>
    where
        T: SpatialScalar + Send + Sync,
        P: SpatialPoint<T> + Send + Sync,
        M: DistanceMetric<T, P> + Send + Sync,
    {
        let n = points.len();
        let mut matrix = vec![T::zero(); n * n];

        for i in 0..n {
            matrix[i * n + i] = T::zero();

            // Process the upper triangle in blocks of four to expose
            // independent distance computations to the optimizer.
            let remaining = n - i - 1;
            let j_chunks = remaining / 4;

            for chunk in 0..j_chunks {
                let j_base = i + 1 + chunk * 4;

                let j0 = j_base;
                let j1 = j_base + 1;
                let j2 = j_base + 2;
                let j3 = j_base + 3;

                let distance0 = metric.distance(&points[i], &points[j0]);
                let distance1 = metric.distance(&points[i], &points[j1]);
                let distance2 = metric.distance(&points[i], &points[j2]);
                let distance3 = metric.distance(&points[i], &points[j3]);

                matrix[i * n + j0] = distance0;
                matrix[j0 * n + i] = distance0;
                matrix[i * n + j1] = distance1;
                matrix[j1 * n + i] = distance1;
                matrix[i * n + j2] = distance2;
                matrix[j2 * n + i] = distance2;
                matrix[i * n + j3] = distance3;
                matrix[j3 * n + i] = distance3;
            }

            // Remainder that did not fill a block of four.
            for j in (i + 1 + j_chunks * 4)..n {
                let distance = metric.distance(&points[i], &points[j]);
                matrix[i * n + j] = distance;
                matrix[j * n + i] = distance;
            }
        }

        Ok(matrix)
    }

    fn compute_basic<T, P, M>(points: &[P], metric: &M) -> SpatialResult<Vec<Vec<T>>>
    where
        T: SpatialScalar,
        P: SpatialPoint<T>,
        M: DistanceMetric<T, P>,
    {
        let n = points.len();
        let mut matrix = vec![vec![T::zero(); n]; n];

        for i in 0..n {
            matrix[i][i] = T::zero();

            let remaining = n - i - 1;
            let j_chunks = remaining / 4;

            for chunk in 0..j_chunks {
                let j_base = i + 1 + chunk * 4;

                let j0 = j_base;
                let j1 = j_base + 1;
                let j2 = j_base + 2;
                let j3 = j_base + 3;

                let distance0 = metric.distance(&points[i], &points[j0]);
                let distance1 = metric.distance(&points[i], &points[j1]);
                let distance2 = metric.distance(&points[i], &points[j2]);
                let distance3 = metric.distance(&points[i], &points[j3]);

                matrix[i][j0] = distance0;
                matrix[j0][i] = distance0;
                matrix[i][j1] = distance1;
                matrix[j1][i] = distance1;
                matrix[i][j2] = distance2;
                matrix[j2][i] = distance2;
                matrix[i][j3] = distance3;
                matrix[j3][i] = distance3;
            }

            for j in (i + 1 + j_chunks * 4)..n {
                let distance = metric.distance(&points[i], &points[j]);
                matrix[i][j] = distance;
                matrix[j][i] = distance;
            }
        }

        Ok(matrix)
    }

    fn compute_simd_optimized<T, P, M>(points: &[P], metric: &M) -> SpatialResult<Vec<Vec<T>>>
    where
        T: SpatialScalar + Send + Sync,
        P: SpatialPoint<T> + Send + Sync,
        M: DistanceMetric<T, P> + Send + Sync,
    {
        let n = points.len();
        let mut matrix = vec![vec![T::zero(); n]; n];
        let caps = PlatformCapabilities::detect();

        const SIMD_CHUNK_SIZE: usize = 4;

        if caps.simd_available {
            for i in 0..n {
                let point_i = &points[i];

                let mut j = i + 1;
                while j < n {
                    let chunk_end = (j + SIMD_CHUNK_SIZE).min(n);

                    // Specialized low-dimensional kernels; everything else
                    // falls back to the metric's own distance.
                    if let Some(dimension) = Self::get_dimension(point_i) {
                        if dimension <= 4 {
                            Self::compute_simd_chunk(
                                &mut matrix,
                                i,
                                j,
                                chunk_end,
                                points,
                                metric,
                                dimension,
                            );
                        } else {
                            for k in j..chunk_end {
                                let distance = metric.distance(point_i, &points[k]);
                                matrix[i][k] = distance;
                                matrix[k][i] = distance;
                            }
                        }
                    } else {
                        for k in j..chunk_end {
                            let distance = metric.distance(point_i, &points[k]);
                            matrix[i][k] = distance;
                            matrix[k][i] = distance;
                        }
                    }

                    j = chunk_end;
                }
            }
        } else {
            return Self::compute_basic(points, metric);
        }

        Ok(matrix)
    }

    fn get_dimension<T, P>(point: &P) -> Option<usize>
    where
        T: SpatialScalar,
        P: SpatialPoint<T>,
    {
        // Only dimensions 1..=4 have specialized kernels.
        let dim = point.dimension();
        if dim > 0 && dim <= 4 {
            Some(dim)
        } else {
            None
        }
    }

    fn compute_simd_chunk<T, P, M>(
        matrix: &mut [Vec<T>],
        i: usize,
        j_start: usize,
        j_end: usize,
        points: &[P],
        metric: &M,
        dimension: usize,
    ) where
        T: SpatialScalar,
        P: SpatialPoint<T>,
        M: DistanceMetric<T, P>,
    {
        let point_i = &points[i];

        // NOTE: the 2D/3D fast paths hard-code the Euclidean distance and
        // are only equivalent to `metric` when the metric is Euclidean.
        match dimension {
            2 => {
                let xi = point_i.coordinate(0).unwrap_or(T::zero());
                let yi = point_i.coordinate(1).unwrap_or(T::zero());

                for k in j_start..j_end {
                    let point_k = &points[k];
                    let xk = point_k.coordinate(0).unwrap_or(T::zero());
                    let yk = point_k.coordinate(1).unwrap_or(T::zero());

                    let dx = xi - xk;
                    let dy = yi - yk;
                    let distance_sq = dx * dx + dy * dy;
                    let distance = distance_sq.sqrt();

                    matrix[i][k] = distance;
                    matrix[k][i] = distance;
                }
            }
            3 => {
                let xi = point_i.coordinate(0).unwrap_or(T::zero());
                let yi = point_i.coordinate(1).unwrap_or(T::zero());
                let zi = point_i.coordinate(2).unwrap_or(T::zero());

                for k in j_start..j_end {
                    let point_k = &points[k];
                    let xk = point_k.coordinate(0).unwrap_or(T::zero());
                    let yk = point_k.coordinate(1).unwrap_or(T::zero());
                    let zk = point_k.coordinate(2).unwrap_or(T::zero());

                    let dx = xi - xk;
                    let dy = yi - yk;
                    let dz = zi - zk;
                    let distance_sq = dx * dx + dy * dy + dz * dz;
                    let distance = distance_sq.sqrt();

                    matrix[i][k] = distance;
                    matrix[k][i] = distance;
                }
            }
            _ => {
                for k in j_start..j_end {
                    let distance = metric.distance(point_i, &points[k]);
                    matrix[i][k] = distance;
                    matrix[k][i] = distance;
                }
            }
        }
    }

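    /// Computes the full distance matrix in parallel, choosing a
    /// memory-efficient row-parallel path for larger inputs.
    ///
    /// A short usage sketch, assuming `points` and `metric` as in `compute`:
    ///
    /// ```ignore
    /// // Falls back to the pair-list path for n <= 1000.
    /// let matrix = GenericDistanceMatrix::compute_parallel(&points, &metric)?;
    /// assert_eq!(matrix.len(), points.len());
    /// ```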
    pub fn compute_parallel<T, P, M>(points: &[P], metric: &M) -> SpatialResult<Vec<Vec<T>>>
    where
        T: SpatialScalar + Send + Sync,
        P: SpatialPoint<T> + Send + Sync + Clone,
        M: DistanceMetric<T, P> + Send + Sync,
    {
        let n = points.len();

        if n > 1000 {
            Self::compute_parallel_memory_efficient(points, metric)
        } else {
            Self::compute_parallel_basic(points, metric)
        }
    }

    fn compute_parallel_basic<T, P, M>(points: &[P], metric: &M) -> SpatialResult<Vec<Vec<T>>>
    where
        T: SpatialScalar + Send + Sync,
        P: SpatialPoint<T> + Send + Sync + Clone,
        M: DistanceMetric<T, P> + Send + Sync,
    {
        let n = points.len();
        let mut matrix = vec![vec![T::zero(); n]; n];

        // Enumerate the upper triangle (including the diagonal) and compute
        // all pairs in parallel; shared references to `points` and `metric`
        // are Sync, so no extra wrapping is needed.
        let indices: Vec<(usize, usize)> =
            (0..n).flat_map(|i| (i..n).map(move |j| (i, j))).collect();

        let distances: Vec<T> = indices
            .par_iter()
            .map(|&(i, j)| {
                if i == j {
                    T::zero()
                } else {
                    metric.distance(&points[i], &points[j])
                }
            })
            .collect();

        // Scatter the results symmetrically.
        for (idx, &(i, j)) in indices.iter().enumerate() {
            matrix[i][j] = distances[idx];
            matrix[j][i] = distances[idx];
        }

        Ok(matrix)
    }

    fn compute_parallel_memory_efficient<T, P, M>(
        points: &[P],
        metric: &M,
    ) -> SpatialResult<Vec<Vec<T>>>
    where
        T: SpatialScalar + Send + Sync,
        P: SpatialPoint<T> + Send + Sync + Clone,
        M: DistanceMetric<T, P> + Send + Sync,
    {
        let n = points.len();
        let mut matrix = vec![vec![T::zero(); n]; n];

        // Each row depends only on `points`, so rows can be filled in
        // parallel without shared mutable state. The row kernels write the
        // entire row (distance to self is zero), so the matrix comes out
        // symmetric by construction.
        matrix.par_iter_mut().enumerate().for_each(|(i, row)| {
            if points[i].dimension() <= 4 {
                Self::compute_row_distances_simd(&points[i], points, row, metric);
            } else {
                Self::compute_row_distances_scalar(&points[i], points, row, metric);
            }
        });

        Ok(matrix)
    }

    fn compute_row_distances_simd<T, P, M>(
        point_i: &P,
        points: &[P],
        distances: &mut [T],
        metric: &M,
    ) where
        T: SpatialScalar,
        P: SpatialPoint<T>,
        M: DistanceMetric<T, P>,
    {
        // NOTE: as in `compute_simd_chunk`, the 2D/3D paths hard-code the
        // Euclidean distance; other dimensions defer to the metric.
        match point_i.dimension() {
            2 => {
                let xi = point_i.coordinate(0).unwrap_or(T::zero());
                let yi = point_i.coordinate(1).unwrap_or(T::zero());

                for (j, point_j) in points.iter().enumerate() {
                    let xj = point_j.coordinate(0).unwrap_or(T::zero());
                    let yj = point_j.coordinate(1).unwrap_or(T::zero());

                    let dx = xi - xj;
                    let dy = yi - yj;
                    distances[j] = (dx * dx + dy * dy).sqrt();
                }
            }
            3 => {
                let xi = point_i.coordinate(0).unwrap_or(T::zero());
                let yi = point_i.coordinate(1).unwrap_or(T::zero());
                let zi = point_i.coordinate(2).unwrap_or(T::zero());

                for (j, point_j) in points.iter().enumerate() {
                    let xj = point_j.coordinate(0).unwrap_or(T::zero());
                    let yj = point_j.coordinate(1).unwrap_or(T::zero());
                    let zj = point_j.coordinate(2).unwrap_or(T::zero());

                    let dx = xi - xj;
                    let dy = yi - yj;
                    let dz = zi - zj;
                    distances[j] = (dx * dx + dy * dy + dz * dz).sqrt();
                }
            }
            _ => {
                Self::compute_row_distances_scalar(point_i, points, distances, metric);
            }
        }
    }

    fn compute_row_distances_scalar<T, P, M>(
        point_i: &P,
        points: &[P],
        distances: &mut [T],
        metric: &M,
    ) where
        T: SpatialScalar,
        P: SpatialPoint<T>,
        M: DistanceMetric<T, P>,
    {
        for (j, point_j) in points.iter().enumerate() {
            distances[j] = metric.distance(point_i, point_j);
        }
    }

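    /// Computes the condensed distance vector: the upper triangle of the
    /// distance matrix in row-major order, as in SciPy's `pdist`. The pair
    /// `(i, j)` with `i < j` lands at index
    /// `i * n - i * (i + 1) / 2 + (j - i - 1)`.
    ///
    /// A short sketch under the same assumptions as `compute`:
    ///
    /// ```ignore
    /// let condensed = GenericDistanceMatrix::compute_condensed(&points, &metric)?;
    /// assert_eq!(condensed.len(), points.len() * (points.len() - 1) / 2);
    /// ```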
    pub fn compute_condensed<T, P, M>(points: &[P], metric: &M) -> SpatialResult<Vec<T>>
    where
        T: SpatialScalar,
        P: SpatialPoint<T>,
        M: DistanceMetric<T, P>,
    {
        let n = points.len();
        // `saturating_sub` keeps the capacity computation from underflowing
        // on an empty input.
        let mut distances = Vec::with_capacity(n * n.saturating_sub(1) / 2);

        for i in 0..n {
            for j in (i + 1)..n {
                distances.push(metric.distance(&points[i], &points[j]));
            }
        }

        Ok(distances)
    }
}

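/// K-means clustering over generic points, using farthest-point
/// initialization and a builder-style configuration API.
///
/// A minimal usage sketch, assuming `Point<f64>` values from
/// `generic_traits`:
///
/// ```ignore
/// let kmeans = GenericKMeans::new(2)
///     .with_max_iterations(100)
///     .with_tolerance(1e-6);
/// let result = kmeans.fit(&points)?;
/// println!("converged: {} after {} iterations", result.converged, result.iterations);
/// ```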
pub struct GenericKMeans<T: SpatialScalar, P: SpatialPoint<T>> {
    k: usize,
    max_iterations: usize,
    tolerance: T,
    parallel: bool,
    phantom: PhantomData<(T, P)>,
}

impl<T: SpatialScalar, P: SpatialPoint<T> + Clone> GenericKMeans<T, P> {
    pub fn new(k: usize) -> Self {
        Self {
            k,
            // Deliberately loose defaults; tune via the builder methods.
            max_iterations: 5,
            tolerance: T::from_f64(1e-1).unwrap_or(<T as SpatialScalar>::epsilon()),
            parallel: false,
            phantom: PhantomData,
        }
    }

    pub fn with_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    pub fn with_tolerance(mut self, tolerance: T) -> Self {
        self.tolerance = tolerance;
        self
    }

    pub fn fit(&self, points: &[P]) -> SpatialResult<KMeansResult<T, P>> {
        if points.is_empty() {
            return Err(SpatialError::ValueError(
                "Cannot cluster empty point set".to_string(),
            ));
        }

        if self.k == 0 {
            return Err(SpatialError::ValueError(
                "k must be greater than 0".to_string(),
            ));
        }

        if self.k > points.len() {
            return Err(SpatialError::ValueError(format!(
                "k ({}) cannot be larger than the number of points ({})",
                self.k,
                points.len()
            )));
        }

        if self.k > 10000 {
            return Err(SpatialError::ValueError(format!(
                "k ({}) is too large. Consider using hierarchical clustering for k > 10000",
                self.k
            )));
        }

        let dimension = points[0].dimension();

        if dimension == 0 {
            return Err(SpatialError::ValueError(
                "Points must have at least one dimension".to_string(),
            ));
        }

        for (i, point) in points.iter().enumerate() {
            if point.dimension() != dimension {
                return Err(SpatialError::ValueError(format!(
                    "Point {} has dimension {} but expected {}",
                    i,
                    point.dimension(),
                    dimension
                )));
            }

            for d in 0..dimension {
                if let Some(coord) = point.coordinate(d) {
                    if !Float::is_finite(coord) {
                        return Err(SpatialError::ValueError(format!(
                            "Point {} has invalid coordinate {} at dimension {}",
                            i,
                            NumCast::from(coord).unwrap_or(f64::NAN),
                            d
                        )));
                    }
                }
            }
        }

        let mut centroids = self.initialize_centroids(points, dimension)?;
        let mut assignments = vec![0; points.len()];

        // Scratch buffer reused for the per-point centroid distances.
        let mut point_distances = vec![T::zero(); self.k];

        for iteration in 0..self.max_iterations {
            let mut changed = false;

            // Assignment step. The `parallel` flag is currently a no-op:
            // both configurations run the same serial chunked loop.
            let _parallel_requested = self.parallel && points.len() > 10000;

            const CHUNK_SIZE: usize = 16;
            for (chunk_idx, chunk) in points.chunks(CHUNK_SIZE).enumerate() {
                let chunk_offset = chunk_idx * CHUNK_SIZE;

                for (local_i, point) in chunk.iter().enumerate() {
                    let i = chunk_offset + local_i;
                    let mut best_cluster = 0;
                    let mut best_distance = T::max_finite();

                    self.compute_distances_simd(point, &centroids, &mut point_distances);

                    for (j, &distance) in point_distances.iter().enumerate() {
                        if distance < best_distance {
                            best_distance = distance;
                            best_cluster = j;
                        }
                    }

                    if assignments[i] != best_cluster {
                        assignments[i] = best_cluster;
                        changed = true;
                    }
                }
            }

            // Update step, followed by the convergence check on the largest
            // centroid movement.
            let old_centroids = centroids.clone();
            centroids = self.update_centroids(points, &assignments, dimension)?;

            let max_movement = old_centroids
                .iter()
                .zip(centroids.iter())
                .map(|(old, new)| old.distance_to(new))
                .fold(T::zero(), |acc, dist| if dist > acc { dist } else { acc });

            if !changed || max_movement < self.tolerance {
                return Ok(KMeansResult {
                    centroids,
                    assignments,
                    iterations: iteration + 1,
                    converged: max_movement < self.tolerance,
                    phantom: PhantomData,
                });
            }
        }

        Ok(KMeansResult {
            centroids,
            assignments,
            iterations: self.max_iterations,
            converged: false,
            phantom: PhantomData,
        })
    }

    fn initialize_centroids(
        &self,
        points: &[P],
        _dimension: usize,
    ) -> SpatialResult<Vec<Point<T>>> {
        let mut centroids = Vec::with_capacity(self.k);

        // Farthest-point initialization: seed with the first point, then
        // repeatedly take the point farthest from its nearest centroid.
        centroids.push(GenericKMeans::<T, P>::point_to_generic(&points[0]));

        for _ in 1..self.k {
            let mut distances = Vec::with_capacity(points.len());

            for point in points {
                let min_distance = centroids
                    .iter()
                    .map(|centroid| {
                        GenericKMeans::<T, P>::point_to_generic(point).distance_to(centroid)
                    })
                    .fold(
                        T::max_finite(),
                        |acc, dist| if dist < acc { dist } else { acc },
                    );
                distances.push(min_distance);
            }

            let max_distance_idx = distances
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            centroids.push(GenericKMeans::<T, P>::point_to_generic(
                &points[max_distance_idx],
            ));
        }

        Ok(centroids)
    }

    fn update_centroids(
        &self,
        points: &[P],
        assignments: &[usize],
        dimension: usize,
    ) -> SpatialResult<Vec<Point<T>>> {
        let mut centroids = vec![Point::zeros(dimension); self.k];
        let mut counts = vec![0; self.k];

        // Accumulate coordinate sums per cluster, walking points and their
        // assignments in matching chunks (the assignment slice must be
        // offset by the chunk position, not taken from the front).
        const UPDATE_CHUNK_SIZE: usize = 512;
        for (chunk_idx, chunk) in points.chunks(UPDATE_CHUNK_SIZE).enumerate() {
            let offset = chunk_idx * UPDATE_CHUNK_SIZE;
            let assignments_chunk = &assignments[offset..offset + chunk.len()];

            for (point, &cluster) in chunk.iter().zip(assignments_chunk.iter()) {
                let point_coords: Vec<T> = (0..dimension)
                    .map(|d| point.coordinate(d).unwrap_or(T::zero()))
                    .collect();

                for (d, &coord) in point_coords.iter().enumerate() {
                    if let Some(centroid_coord) = centroids[cluster].coords_mut().get_mut(d) {
                        *centroid_coord = *centroid_coord + coord;
                    }
                }
                counts[cluster] += 1;
            }
        }

        // Divide by the member count; empty clusters keep a zero centroid.
        for (centroid, count) in centroids.iter_mut().zip(counts.iter()) {
            if *count > 0 {
                let count_scalar = T::from(*count).unwrap_or(T::one());
                for coord in centroid.coords_mut() {
                    *coord = *coord / count_scalar;
                }
            }
        }

        Ok(centroids)
    }

    fn point_to_generic(point: &P) -> Point<T> {
        let coords: Vec<T> = (0..point.dimension())
            .map(|i| point.coordinate(i).unwrap_or(T::zero()))
            .collect();
        Point::new(coords)
    }

    fn compute_distances_simd(&self, point: &P, centroids: &[Point<T>], distances: &mut [T]) {
        // Despite the name, this currently delegates to the generic
        // `distance_to`; `compute_distances_simd_optimized` below holds the
        // unrolled low-dimensional variant.
        let point_generic = GenericKMeans::<T, P>::point_to_generic(point);

        for (j, centroid) in centroids.iter().enumerate() {
            distances[j] = point_generic.distance_to(centroid);
        }
    }

    #[allow(dead_code)]
    fn compute_distances_simd_optimized(
        &self,
        point: &Point<T>,
        centroids: &[Point<T>],
        distances: &mut [T],
    ) {
        match point.dimension() {
            2 => {
                let px = point.coordinate(0).unwrap_or(T::zero());
                let py = point.coordinate(1).unwrap_or(T::zero());

                // Unroll blocks of four centroids.
                let mut i = 0;
                while i + 3 < centroids.len() {
                    for j in 0..4 {
                        if i + j < centroids.len() {
                            let centroid = &centroids[i + j];
                            let cx = centroid.coordinate(0).unwrap_or(T::zero());
                            let cy = centroid.coordinate(1).unwrap_or(T::zero());

                            let dx = px - cx;
                            let dy = py - cy;
                            distances[i + j] = (dx * dx + dy * dy).sqrt();
                        }
                    }
                    i += 4;
                }

                // Remainder.
                while i < centroids.len() {
                    let centroid = &centroids[i];
                    let cx = centroid.coordinate(0).unwrap_or(T::zero());
                    let cy = centroid.coordinate(1).unwrap_or(T::zero());

                    let dx = px - cx;
                    let dy = py - cy;
                    distances[i] = (dx * dx + dy * dy).sqrt();
                    i += 1;
                }
            }
            3 => {
                let px = point.coordinate(0).unwrap_or(T::zero());
                let py = point.coordinate(1).unwrap_or(T::zero());
                let pz = point.coordinate(2).unwrap_or(T::zero());

                for (i, centroid) in centroids.iter().enumerate() {
                    let cx = centroid.coordinate(0).unwrap_or(T::zero());
                    let cy = centroid.coordinate(1).unwrap_or(T::zero());
                    let cz = centroid.coordinate(2).unwrap_or(T::zero());

                    let dx = px - cx;
                    let dy = py - cy;
                    let dz = pz - cz;
                    distances[i] = (dx * dx + dy * dy + dz * dz).sqrt();
                }
            }
            _ => {
                for (j, centroid) in centroids.iter().enumerate() {
                    distances[j] = point.distance_to(centroid);
                }
            }
        }
    }
}

#[derive(Debug, Clone)]
pub struct KMeansResult<T: SpatialScalar, P: SpatialPoint<T>> {
    /// Final cluster centroids.
    pub centroids: Vec<Point<T>>,
    /// Cluster index assigned to each input point.
    pub assignments: Vec<usize>,
    /// Number of iterations actually run.
    pub iterations: usize,
    /// Whether centroid movement fell below the tolerance.
    pub converged: bool,
    phantom: PhantomData<P>,
}

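/// 2D convex hull via the Graham scan.
///
/// A minimal usage sketch, assuming 2D `Point<f64>` values:
///
/// ```ignore
/// let hull = GenericConvexHull::graham_scan_2d(&points)?;
/// // Hull vertices in counter-clockwise order.
/// for p in &hull {
///     println!("({:?}, {:?})", p.coordinate(0), p.coordinate(1));
/// }
/// ```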
pub struct GenericConvexHull;

impl GenericConvexHull {
    pub fn graham_scan_2d<T, P>(points: &[P]) -> SpatialResult<Vec<Point<T>>>
    where
        T: SpatialScalar,
        P: SpatialPoint<T> + Clone,
    {
        if points.is_empty() {
            return Ok(Vec::new());
        }

        // With fewer than three points the input is its own hull.
        if points.len() < 3 {
            return Ok(points.iter().map(|p| Self::to_generic_point(p)).collect());
        }

        for point in points {
            if point.dimension() != 2 {
                return Err(SpatialError::ValueError(
                    "All points must be 2D for 2D convex hull".to_string(),
                ));
            }
        }

        let mut generic_points: Vec<Point<T>> =
            points.iter().map(|p| Self::to_generic_point(p)).collect();

        // Anchor: the lowest point, with ties broken by x.
        let start_idx = generic_points
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| {
                let y_cmp = a
                    .coordinate(1)
                    .partial_cmp(&b.coordinate(1))
                    .expect("coordinates must be comparable");
                if y_cmp == Ordering::Equal {
                    a.coordinate(0)
                        .partial_cmp(&b.coordinate(0))
                        .expect("coordinates must be comparable")
                } else {
                    y_cmp
                }
            })
            .map(|(idx, _)| idx)
            .expect("points is non-empty");

        generic_points.swap(0, start_idx);
        let start_point = generic_points[0].clone();

        // Sort the remaining points by polar angle around the anchor.
        generic_points[1..].sort_by(|a, b| {
            let angle_a = Self::polar_angle(&start_point, a);
            let angle_b = Self::polar_angle(&start_point, b);
            angle_a.partial_cmp(&angle_b).unwrap_or(Ordering::Equal)
        });

        // Scan: pop while the last two hull points and the candidate fail
        // to make a counter-clockwise turn.
        let mut hull = Vec::new();
        for point in generic_points {
            while hull.len() > 1
                && Self::cross_product(&hull[hull.len() - 2], &hull[hull.len() - 1], &point)
                    <= T::zero()
            {
                hull.pop();
            }
            hull.push(point);
        }

        Ok(hull)
    }

    fn to_generic_point<T, P>(point: &P) -> Point<T>
    where
        T: SpatialScalar,
        P: SpatialPoint<T>,
    {
        let coords: Vec<T> = (0..point.dimension())
            .map(|i| point.coordinate(i).unwrap_or(T::zero()))
            .collect();
        Point::new(coords)
    }

    fn polar_angle<T: SpatialScalar>(start: &Point<T>, point: &Point<T>) -> T {
        let dx =
            point.coordinate(0).unwrap_or(T::zero()) - start.coordinate(0).unwrap_or(T::zero());
        let dy =
            point.coordinate(1).unwrap_or(T::zero()) - start.coordinate(1).unwrap_or(T::zero());
        dy.atan2(dx)
    }

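    // Signed cross product of AB and AC: positive when the turn A -> B -> C
    // is counter-clockwise, negative when clockwise, zero when collinear.
    // The Graham scan above pops hull points while this is non-positive.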
    fn cross_product<T: SpatialScalar>(a: &Point<T>, b: &Point<T>, c: &Point<T>) -> T {
        let ab_x = b.coordinate(0).unwrap_or(T::zero()) - a.coordinate(0).unwrap_or(T::zero());
        let ab_y = b.coordinate(1).unwrap_or(T::zero()) - a.coordinate(1).unwrap_or(T::zero());
        let ac_x = c.coordinate(0).unwrap_or(T::zero()) - a.coordinate(0).unwrap_or(T::zero());
        let ac_y = c.coordinate(1).unwrap_or(T::zero()) - a.coordinate(1).unwrap_or(T::zero());

        ab_x * ac_y - ab_y * ac_x
    }
}

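/// Density-based clustering (DBSCAN) over generic points.
///
/// A minimal usage sketch, assuming `Point<f64>` values and a caller-supplied
/// Euclidean `metric`:
///
/// ```ignore
/// let dbscan = GenericDBSCAN::new(0.5, 3);
/// let result = dbscan.fit(&points, &metric)?;
/// // Label -1 marks noise; labels 0..n_clusters identify clusters.
/// println!("{} clusters", result.n_clusters);
/// ```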
pub struct GenericDBSCAN<T: SpatialScalar> {
    eps: T,
    minsamples: usize,
    _phantom: PhantomData<T>,
}

impl<T: SpatialScalar> GenericDBSCAN<T> {
    pub fn new(eps: T, minsamples: usize) -> Self {
        Self {
            eps,
            minsamples,
            _phantom: PhantomData,
        }
    }

    pub fn fit<P, M>(&self, points: &[P], metric: &M) -> SpatialResult<DBSCANResult>
    where
        P: SpatialPoint<T>,
        M: DistanceMetric<T, P>,
    {
        if points.is_empty() {
            return Ok(DBSCANResult {
                labels: Vec::new(),
                n_clusters: 0,
            });
        }

        if self.minsamples == 0 {
            return Err(SpatialError::ValueError(
                "minsamples must be greater than 0".to_string(),
            ));
        }

        if self.minsamples > points.len() {
            return Err(SpatialError::ValueError(format!(
                "minsamples ({}) cannot be larger than the number of points ({})",
                self.minsamples,
                points.len()
            )));
        }

        if !Float::is_finite(self.eps) || self.eps <= T::zero() {
            return Err(SpatialError::ValueError(format!(
                "eps must be a positive finite number, got: {}",
                NumCast::from(self.eps).unwrap_or(f64::NAN)
            )));
        }

        // `points` is non-empty here, so the reference dimension is well
        // defined.
        let dimension = points[0].dimension();
        for (i, point) in points.iter().enumerate() {
            if point.dimension() != dimension {
                return Err(SpatialError::ValueError(format!(
                    "Point {} has dimension {} but expected {}",
                    i,
                    point.dimension(),
                    dimension
                )));
            }

            for d in 0..dimension {
                if let Some(coord) = point.coordinate(d) {
                    if !Float::is_finite(coord) {
                        return Err(SpatialError::ValueError(format!(
                            "Point {} has invalid coordinate {} at dimension {}",
                            i,
                            NumCast::from(coord).unwrap_or(f64::NAN),
                            d
                        )));
                    }
                }
            }
        }

        let n = points.len();
        // -1 marks noise; non-negative labels are cluster ids.
        let mut labels = vec![-1i32; n];
        let mut visited = vec![false; n];
        let mut cluster_id = 0;

        for i in 0..n {
            if visited[i] {
                continue;
            }
            visited[i] = true;

            let neighbors = self.find_neighbors(points, i, metric);

            if neighbors.len() < self.minsamples {
                // Not a core point; stays noise unless a later cluster
                // absorbs it as a border point.
                labels[i] = -1;
            } else {
                self.expand_cluster(
                    points,
                    &mut labels,
                    &mut visited,
                    i,
                    &neighbors,
                    cluster_id,
                    metric,
                );
                cluster_id += 1;

                if cluster_id > 10000 {
                    return Err(SpatialError::ValueError(format!(
                        "Too many clusters found: {cluster_id}. Consider adjusting eps or minsamples parameters"
                    )));
                }
            }
        }

        Ok(DBSCANResult {
            labels,
            n_clusters: cluster_id,
        })
    }

    fn find_neighbors<P, M>(&self, points: &[P], pointidx: usize, metric: &M) -> Vec<usize>
    where
        P: SpatialPoint<T>,
        M: DistanceMetric<T, P>,
    {
        // Plain linear scan: every point within `eps` of the query
        // (including the query itself) is a neighbor. The list must be
        // complete; truncating it would silently corrupt the clustering.
        let query_point = &points[pointidx];
        let mut neighbors = Vec::new();

        for (i, point) in points.iter().enumerate() {
            if metric.distance(query_point, point) <= self.eps {
                neighbors.push(i);
            }
        }

        neighbors
    }

    #[allow(clippy::too_many_arguments)]
    fn expand_cluster<P, M>(
        &self,
        points: &[P],
        labels: &mut [i32],
        visited: &mut [bool],
        pointidx: usize,
        neighbors: &[usize],
        cluster_id: i32,
        metric: &M,
    ) where
        P: SpatialPoint<T>,
        M: DistanceMetric<T, P>,
    {
        labels[pointidx] = cluster_id;

        let mut processed = vec![false; points.len()];
        let mut seed_set = Vec::with_capacity(neighbors.len() * 2);

        for &neighbor in neighbors {
            if neighbor < points.len() {
                seed_set.push(neighbor);
            }
        }

        // Drain the seed set in batches; core points push their own
        // neighbors back onto the set.
        const EXPAND_BATCH_SIZE: usize = 32;
        let mut batch_buffer = Vec::with_capacity(EXPAND_BATCH_SIZE);

        while !seed_set.is_empty() {
            let batch_size = seed_set.len().min(EXPAND_BATCH_SIZE);
            batch_buffer.clear();
            batch_buffer.extend(seed_set.drain(..batch_size));

            for q in batch_buffer.iter().copied() {
                if q >= points.len() || processed[q] {
                    continue;
                }
                processed[q] = true;

                if !visited[q] {
                    visited[q] = true;
                    let q_neighbors = self.find_neighbors(points, q, metric);

                    if q_neighbors.len() >= self.minsamples {
                        for &neighbor in &q_neighbors {
                            if neighbor < points.len()
                                && !processed[neighbor]
                                && !seed_set.contains(&neighbor)
                            {
                                seed_set.push(neighbor);
                            }
                        }
                    }
                }

                // Claim noise/unlabeled points for this cluster.
                if labels[q] == -1 {
                    labels[q] = cluster_id;
                }
            }

            // Keep the seed set small if it grows pathologically.
            if seed_set.len() > 1000 {
                seed_set.sort_unstable();
                seed_set.dedup();
            }
        }
    }
}

#[derive(Debug, Clone)]
pub struct DBSCANResult {
    /// Cluster label per point; -1 marks noise.
    pub labels: Vec<i32>,
    /// Number of clusters found.
    pub n_clusters: i32,
}

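/// Gaussian mixture model fitted with expectation-maximization, seeded by a
/// k-means run.
///
/// A minimal usage sketch, assuming `Point<f64>` values:
///
/// ```ignore
/// let gmm = GenericGMM::new(2)
///     .with_max_iterations(50)
///     .with_tolerance(1e-4);
/// let result = gmm.fit(&points)?;
/// println!("log-likelihood: {:?}", result.log_likelihood);
/// ```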
pub struct GenericGMM<T: SpatialScalar> {
    _ncomponents: usize,
    max_iterations: usize,
    tolerance: T,
    reg_covar: T,
    _phantom: PhantomData<T>,
}

impl<T: SpatialScalar> GenericGMM<T> {
    pub fn new(ncomponents: usize) -> Self {
        Self {
            _ncomponents: ncomponents,
            // Deliberately loose defaults; tune via the builder methods.
            max_iterations: 3,
            tolerance: T::from_f64(1e-1).unwrap_or(<T as SpatialScalar>::epsilon()),
            reg_covar: T::from_f64(1e-6).unwrap_or(<T as SpatialScalar>::epsilon()),
            _phantom: PhantomData,
        }
    }

    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    pub fn with_tolerance(mut self, tolerance: T) -> Self {
        self.tolerance = tolerance;
        self
    }

    pub fn with_reg_covar(mut self, reg_covar: T) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    #[allow(clippy::needless_range_loop)]
    pub fn fit<P>(&self, points: &[P]) -> SpatialResult<GMMResult<T>>
    where
        P: SpatialPoint<T> + Clone,
    {
        if points.is_empty() {
            return Err(SpatialError::ValueError(
                "Cannot fit GMM to empty dataset".to_string(),
            ));
        }

        let n_samples = points.len();
        let n_features = points[0].dimension();

        // Seed the mixture with a k-means run: its centroids become the
        // initial means and its clusters the initial covariances.
        let kmeans = GenericKMeans::new(self._ncomponents);
        let kmeans_result = kmeans.fit(points)?;

        let mut means = kmeans_result.centroids;
        let mut weights = vec![
            T::one() / T::from(self._ncomponents).expect("component count must convert to T");
            self._ncomponents
        ];

        let mut covariances =
            vec![vec![vec![T::zero(); n_features]; n_features]; self._ncomponents];

        for k in 0..self._ncomponents {
            let cluster_points: Vec<&P> = kmeans_result
                .assignments
                .iter()
                .enumerate()
                .filter_map(
                    |(i, &cluster)| {
                        if cluster == k {
                            Some(&points[i])
                        } else {
                            None
                        }
                    },
                )
                .collect();

            if !cluster_points.is_empty() {
                let cluster_mean = &means[k];

                // Sample covariance of the cluster members.
                for i in 0..n_features {
                    for j in 0..n_features {
                        let mut cov_sum = T::zero();
                        let count =
                            T::from(cluster_points.len()).expect("count must convert to T");

                        for point in &cluster_points {
                            let pi = point.coordinate(i).unwrap_or(T::zero())
                                - cluster_mean.coordinate(i).unwrap_or(T::zero());
                            let pj = point.coordinate(j).unwrap_or(T::zero())
                                - cluster_mean.coordinate(j).unwrap_or(T::zero());
                            cov_sum = cov_sum + pi * pj;
                        }

                        covariances[k][i][j] = if count > T::one() {
                            cov_sum / (count - T::one())
                        } else if i == j {
                            T::one()
                        } else {
                            T::zero()
                        };
                    }
                }

                // Regularize the diagonal to keep the matrix invertible.
                for i in 0..n_features {
                    covariances[k][i][i] = covariances[k][i][i] + self.reg_covar;
                }
            } else {
                // Empty cluster: fall back to the identity covariance.
                for i in 0..n_features {
                    covariances[k][i][i] = T::one();
                }
            }
        }

        let mut log_likelihood = T::min_value();
        let mut responsibilities = vec![vec![T::zero(); self._ncomponents]; n_samples];
        let mut converged = false;

        for iteration in 0..self.max_iterations {
            // E-step: per-sample component responsibilities, computed with
            // the log-sum-exp trick for numerical stability.
            let mut new_log_likelihood = T::zero();

            for i in 0..n_samples {
                let point = Self::point_to_generic(&points[i]);
                let mut log_likelihoods = vec![T::min_value(); self._ncomponents];
                let mut max_log_likelihood = T::min_value();

                for k in 0..self._ncomponents {
                    let log_weight = weights[k].ln();
                    let log_gaussian = self.compute_log_gaussian_probability(
                        &point,
                        &means[k],
                        &covariances[k],
                        n_features,
                    );
                    log_likelihoods[k] = log_weight + log_gaussian;
                    if log_likelihoods[k] > max_log_likelihood {
                        max_log_likelihood = log_likelihoods[k];
                    }
                }

                let mut sum_exp = T::zero();
                for k in 0..self._ncomponents {
                    let exp_val = (log_likelihoods[k] - max_log_likelihood).exp();
                    responsibilities[i][k] = exp_val;
                    sum_exp = sum_exp + exp_val;
                }

                if sum_exp > T::zero() {
                    for k in 0..self._ncomponents {
                        responsibilities[i][k] = responsibilities[i][k] / sum_exp;
                    }
                    new_log_likelihood = new_log_likelihood + max_log_likelihood + sum_exp.ln();
                }
            }

            // M-step: effective counts and mixing weights.
            let mut nk_values = vec![T::zero(); self._ncomponents];

            for k in 0..self._ncomponents {
                let mut nk = T::zero();
                for i in 0..n_samples {
                    nk = nk + responsibilities[i][k];
                }
                nk_values[k] = nk;
                weights[k] = nk / T::from(n_samples).expect("sample count must convert to T");
            }

            // M-step: responsibility-weighted means.
            for k in 0..self._ncomponents {
                if nk_values[k] > T::zero() {
                    let mut new_mean_coords = vec![T::zero(); n_features];

                    for i in 0..n_samples {
                        let point = Self::point_to_generic(&points[i]);
                        for d in 0..n_features {
                            let coord = point.coordinate(d).unwrap_or(T::zero());
                            new_mean_coords[d] =
                                new_mean_coords[d] + responsibilities[i][k] * coord;
                        }
                    }

                    for d in 0..n_features {
                        new_mean_coords[d] = new_mean_coords[d] / nk_values[k];
                    }

                    means[k] = Point::new(new_mean_coords);
                }
            }

            // M-step: responsibility-weighted covariances.
            for k in 0..self._ncomponents {
                if nk_values[k] > T::one() {
                    let mean_k = &means[k];

                    for i in 0..n_features {
                        for j in 0..n_features {
                            covariances[k][i][j] = T::zero();
                        }
                    }

                    for sample_idx in 0..n_samples {
                        let point = Self::point_to_generic(&points[sample_idx]);
                        let resp = responsibilities[sample_idx][k];

                        for i in 0..n_features {
                            for j in 0..n_features {
                                let diff_i = point.coordinate(i).unwrap_or(T::zero())
                                    - mean_k.coordinate(i).unwrap_or(T::zero());
                                let diff_j = point.coordinate(j).unwrap_or(T::zero())
                                    - mean_k.coordinate(j).unwrap_or(T::zero());
                                covariances[k][i][j] =
                                    covariances[k][i][j] + resp * diff_i * diff_j;
                            }
                        }
                    }

                    for i in 0..n_features {
                        for j in 0..n_features {
                            covariances[k][i][j] = covariances[k][i][j] / nk_values[k];
                            if i == j {
                                covariances[k][i][j] = covariances[k][i][j] + self.reg_covar;
                            }
                        }
                    }
                }
            }

            // Converged when the log-likelihood stops improving.
            if iteration > 0 && (new_log_likelihood - log_likelihood).abs() < self.tolerance {
                converged = true;
                log_likelihood = new_log_likelihood;
                break;
            }
            log_likelihood = new_log_likelihood;
        }

        // Hard assignment: the component with the largest responsibility.
        let mut labels = vec![0; n_samples];
        for i in 0..n_samples {
            let mut max_resp = T::zero();
            let mut best_cluster = 0;
            for k in 0..self._ncomponents {
                if responsibilities[i][k] > max_resp {
                    max_resp = responsibilities[i][k];
                    best_cluster = k;
                }
            }
            labels[i] = best_cluster;
        }

        Ok(GMMResult {
            means,
            weights,
            covariances,
            labels,
            log_likelihood,
            converged,
        })
    }

    fn point_to_generic<P>(point: &P) -> Point<T>
    where
        P: SpatialPoint<T>,
    {
        let coords: Vec<T> = (0..point.dimension())
            .map(|i| point.coordinate(i).unwrap_or(T::zero()))
            .collect();
        Point::new(coords)
    }

    fn compute_log_gaussian_probability(
        &self,
        point: &Point<T>,
        mean: &Point<T>,
        covariance: &[Vec<T>],
        n_features: usize,
    ) -> T {
        let mut diff = vec![T::zero(); n_features];
        for (i, item) in diff.iter_mut().enumerate().take(n_features) {
            *item =
                point.coordinate(i).unwrap_or(T::zero()) - mean.coordinate(i).unwrap_or(T::zero());
        }

        // Diagonal approximation: the determinant and inverse are taken
        // from the diagonal only, ignoring off-diagonal covariance terms.
        let mut det = T::one();
        let mut inv_cov = vec![vec![T::zero(); n_features]; n_features];

        for i in 0..n_features {
            det = det * covariance[i][i];
            inv_cov[i][i] = T::one() / covariance[i][i];
        }

        let mut quadratic_form = T::zero();
        for i in 0..n_features {
            for j in 0..n_features {
                quadratic_form = quadratic_form + diff[i] * inv_cov[i][j] * diff[j];
            }
        }

        // log N(x | mu, Sigma) = -1/2 (k ln(2 pi) + ln|Sigma| + quadratic form)
        let two_pi = T::from(std::f64::consts::TAU).expect("TAU must convert to T");
        let log_2pi_k =
            T::from(n_features).expect("feature count must convert to T") * two_pi.ln();
        let log_det = det.abs().ln();

        let log_prob =
            -T::from(0.5).expect("0.5 must convert to T") * (log_2pi_k + log_det + quadratic_form);

        if Float::is_finite(log_prob) {
            log_prob
        } else {
            T::min_value()
        }
    }
}

#[derive(Debug, Clone)]
pub struct GMMResult<T: SpatialScalar> {
    /// Component means.
    pub means: Vec<Point<T>>,
    /// Mixing weights, one per component.
    pub weights: Vec<T>,
    /// Covariance matrices, one per component.
    pub covariances: Vec<Vec<Vec<T>>>,
    /// Hard cluster assignment per sample.
    pub labels: Vec<usize>,
    /// Final log-likelihood of the fit.
    pub log_likelihood: T,
    /// Whether the EM loop converged within `max_iterations`.
    pub converged: bool,
}

#[cfg(test)]
#[path = "generic_algorithms_tests.rs"]
mod tests;