1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
36use scirs2_core::numeric::{Float, FromPrimitive};
37use std::collections::HashMap;
38use std::fmt::{Debug, Display};
39use std::hash::{Hash, Hasher};
40use std::marker::{Send, Sync};
41
42use crate::error::ClusteringError;
43use scirs2_core::validation::{
44 check_positive, checkarray_finite, clustering::validate_clustering_data,
45 parameters::check_unit_interval,
46};
47use scirs2_spatial::distance::EuclideanDistance;
48use scirs2_spatial::kdtree::KDTree;
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq)]
52pub enum KernelType {
53 Flat,
55 Gaussian,
57}
58
59impl Default for KernelType {
60 fn default() -> Self {
61 KernelType::Flat
62 }
63}
64
65#[derive(Debug, Clone, Copy, PartialEq, Eq)]
67pub enum BandwidthEstimator {
68 KNNQuantile,
70 Silverman,
72 Scott,
74}
75
76impl Default for BandwidthEstimator {
77 fn default() -> Self {
78 BandwidthEstimator::KNNQuantile
79 }
80}
81
82pub struct MeanShiftOptions<T: Float> {
84 pub bandwidth: Option<T>,
87
88 pub seeds: Option<Array2<T>>,
91
92 pub bin_seeding: bool,
94
95 pub min_bin_freq: usize,
97
98 pub cluster_all: bool,
100
101 pub max_iter: usize,
103
104 pub kernel: KernelType,
106
107 pub bandwidth_estimator: BandwidthEstimator,
109}
110
111impl<T: Float> Default for MeanShiftOptions<T> {
112 fn default() -> Self {
113 Self {
114 bandwidth: None,
115 seeds: None,
116 bin_seeding: false,
117 min_bin_freq: 1,
118 cluster_all: true,
119 max_iter: 300,
120 kernel: KernelType::Flat,
121 bandwidth_estimator: BandwidthEstimator::KNNQuantile,
122 }
123 }
124}
125
126#[derive(Debug, Clone)]
128struct FloatPoint<T: Float>(Vec<T>);
129
130impl<T: Float> PartialEq for FloatPoint<T> {
131 fn eq(&self, other: &Self) -> bool {
132 if self.0.len() != other.0.len() {
133 return false;
134 }
135
136 for (a, b) in self.0.iter().zip(other.0.iter()) {
137 if !a.is_finite() || !b.is_finite() || (*a - *b).abs() > T::epsilon() {
138 return false;
139 }
140 }
141 true
142 }
143}
144
145impl<T: Float> Eq for FloatPoint<T> {}
146
147impl<T: Float> Hash for FloatPoint<T> {
148 fn hash<H: Hasher>(&self, state: &mut H) {
149 for value in &self.0 {
150 let bits = if let Some(bits) = value.to_f64() {
151 (bits * 1e10).round() as i64
152 } else {
153 0
154 };
155 bits.hash(state);
156 }
157 }
158}
159
160pub fn estimate_bandwidth_silverman<T: Float + Display + FromPrimitive + Send + Sync + 'static>(
166 data: &ArrayView2<T>,
167) -> Result<T, ClusteringError> {
168 checkarray_finite(data, "data")?;
169
170 let n = data.nrows();
171 if n < 2 {
172 return Ok(T::from(1.0).ok_or_else(|| {
173 ClusteringError::ComputationError("Failed to convert constant".into())
174 })?);
175 }
176
177 let n_features = data.ncols();
178 let n_f = T::from(n)
179 .ok_or_else(|| ClusteringError::ComputationError("Failed to convert n".into()))?;
180
181 let mut bandwidth_sum = T::zero();
183
184 for col_idx in 0..n_features {
185 let mut values: Vec<T> = (0..n).map(|i| data[[i, col_idx]]).collect();
187 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
188
189 let mean = values.iter().fold(T::zero(), |a, &b| a + b) / n_f;
191 let var = values
192 .iter()
193 .fold(T::zero(), |acc, &v| acc + (v - mean) * (v - mean))
194 / n_f;
195 let std_dev = var.sqrt();
196
197 let q1_idx = n / 4;
199 let q3_idx = (3 * n) / 4;
200 let iqr = values[q3_idx.min(n - 1)] - values[q1_idx];
201 let one_point_three_four = T::from(1.34).ok_or_else(|| {
202 ClusteringError::ComputationError("Failed to convert constant".into())
203 })?;
204 let iqr_scaled = iqr / one_point_three_four;
205
206 let spread = if iqr_scaled > T::zero() && iqr_scaled < std_dev {
208 iqr_scaled
209 } else {
210 std_dev
211 };
212
213 let zero_nine = T::from(0.9).ok_or_else(|| {
215 ClusteringError::ComputationError("Failed to convert constant".into())
216 })?;
217 let exponent = T::from(-0.2).ok_or_else(|| {
218 ClusteringError::ComputationError("Failed to convert constant".into())
219 })?;
220 let n_factor = n_f.powf(exponent);
221
222 let h = zero_nine * spread * n_factor;
223 bandwidth_sum = bandwidth_sum + h;
224 }
225
226 let n_feat_f = T::from(n_features)
227 .ok_or_else(|| ClusteringError::ComputationError("Failed to convert n_features".into()))?;
228 let bandwidth = bandwidth_sum / n_feat_f;
229
230 if bandwidth <= T::zero() {
232 return Ok(T::from(1.0).ok_or_else(|| {
233 ClusteringError::ComputationError("Failed to convert constant".into())
234 })?);
235 }
236
237 Ok(bandwidth)
238}
239
240pub fn estimate_bandwidth_scott<T: Float + Display + FromPrimitive + Send + Sync + 'static>(
246 data: &ArrayView2<T>,
247) -> Result<T, ClusteringError> {
248 checkarray_finite(data, "data")?;
249
250 let n = data.nrows();
251 if n < 2 {
252 return Ok(T::from(1.0).ok_or_else(|| {
253 ClusteringError::ComputationError("Failed to convert constant".into())
254 })?);
255 }
256
257 let n_features = data.ncols();
258 let n_f = T::from(n)
259 .ok_or_else(|| ClusteringError::ComputationError("Failed to convert n".into()))?;
260
261 let d_plus_4 = T::from(n_features as f64 + 4.0)
263 .ok_or_else(|| ClusteringError::ComputationError("Failed to convert dimension".into()))?;
264 let exponent = T::from(-1.0)
265 .ok_or_else(|| ClusteringError::ComputationError("Failed to convert constant".into()))?
266 / d_plus_4;
267 let n_factor = n_f.powf(exponent);
268
269 let mut std_sum = T::zero();
271 for col_idx in 0..n_features {
272 let mean = (0..n)
273 .map(|i| data[[i, col_idx]])
274 .fold(T::zero(), |a, b| a + b)
275 / n_f;
276 let var = (0..n)
277 .map(|i| {
278 let diff = data[[i, col_idx]] - mean;
279 diff * diff
280 })
281 .fold(T::zero(), |a, b| a + b)
282 / n_f;
283 std_sum = std_sum + var.sqrt();
284 }
285
286 let avg_std = std_sum
287 / T::from(n_features).ok_or_else(|| {
288 ClusteringError::ComputationError("Failed to convert n_features".into())
289 })?;
290
291 let bandwidth = n_factor * avg_std;
292
293 if bandwidth <= T::zero() {
294 return Ok(T::from(1.0).ok_or_else(|| {
295 ClusteringError::ComputationError("Failed to convert constant".into())
296 })?);
297 }
298
299 Ok(bandwidth)
300}
301
302pub fn estimate_bandwidth<T: Float + Display + FromPrimitive + Send + Sync + 'static>(
307 data: &ArrayView2<T>,
308 quantile: Option<T>,
309 n_samples: Option<usize>,
310 _random_state: Option<u64>,
311) -> Result<T, ClusteringError> {
312 checkarray_finite(data, "data")?;
313
314 let quantile = quantile
315 .unwrap_or_else(|| T::from(0.3).unwrap_or_else(|| T::from(0.3f64).unwrap_or(T::one())));
316 let _quantile = check_unit_interval(quantile, "quantile", "estimate_bandwidth")?;
317
318 let data = if let Some(n) = n_samples {
320 if n >= data.nrows() {
321 data.to_owned()
322 } else {
323 let mut rng = scirs2_core::random::rng();
324 use scirs2_core::random::seq::SliceRandom;
325 let mut indices: Vec<usize> = (0..data.nrows()).collect();
326 indices.shuffle(&mut rng);
327
328 let indices = &indices[0..n];
329 let mut sampled_data = Array2::zeros((n, data.ncols()));
330 for (i, &idx) in indices.iter().enumerate() {
331 sampled_data.row_mut(i).assign(&data.row(idx));
332 }
333 sampled_data
334 }
335 } else {
336 data.to_owned()
337 };
338
339 let n_neighbors = (T::from(data.nrows()).unwrap_or(T::one()) * quantile)
340 .to_usize()
341 .unwrap_or(1)
342 .max(1)
343 .min(data.nrows().saturating_sub(1));
344
345 let kdtree = KDTree::<_, EuclideanDistance<T>>::new(&data)
347 .map_err(|e| ClusteringError::ComputationError(format!("Failed to build KDTree: {}", e)))?;
348
349 let mut bandwidth_sum = T::zero();
350
351 let batch_size = 500;
352 for i in (0..data.nrows()).step_by(batch_size) {
353 let end = (i + batch_size).min(data.nrows());
354 let batch = data.slice(scirs2_core::ndarray::s![i..end, ..]);
355
356 for row in batch.rows() {
357 let (_, distances) = kdtree.query(&row.to_vec(), n_neighbors + 1).map_err(|e| {
358 ClusteringError::ComputationError(format!("Failed to query KDTree: {}", e))
359 })?;
360
361 if distances.len() > 1 {
362 let kth_dist = distances
363 .last()
364 .copied()
365 .unwrap_or_else(|| T::from(1.0).unwrap_or(T::one()));
366 bandwidth_sum = bandwidth_sum + kth_dist;
367 } else if !distances.is_empty() {
368 bandwidth_sum = bandwidth_sum + T::from(1.0).unwrap_or(T::one());
369 }
370 }
371 }
372
373 Ok(bandwidth_sum / T::from(data.nrows()).unwrap_or(T::one()))
374}
375
376pub fn get_bin_seeds<T: Float + Display + FromPrimitive + Send + Sync + 'static>(
378 data: &ArrayView2<T>,
379 bin_size: T,
380 min_bin_freq: usize,
381) -> Array2<T> {
382 if bin_size <= T::zero() {
383 return data.to_owned();
384 }
385
386 let mut bin_sizes: HashMap<FloatPoint<T>, usize> = HashMap::new();
387
388 for row in data.rows() {
389 let mut binned_point = Vec::with_capacity(row.len());
390 for &val in row.iter() {
391 binned_point.push((val / bin_size).round() * bin_size);
392 }
393 let point = FloatPoint::<T>(binned_point);
394 *bin_sizes.entry(point).or_insert(0) += 1;
395 }
396
397 let seeds: Vec<Vec<T>> = bin_sizes
398 .into_iter()
399 .filter(|(_, freq)| *freq >= min_bin_freq)
400 .map(|(point, _)| point.0)
401 .collect();
402
403 if seeds.len() == data.nrows() {
404 return data.to_owned();
405 }
406
407 if seeds.is_empty() {
408 Array2::zeros((0, data.ncols()))
409 } else {
410 let mut result = Array2::zeros((seeds.len(), data.ncols()));
411 for (i, seed) in seeds.into_iter().enumerate() {
412 for (j, val) in seed.into_iter().enumerate() {
413 result[[i, j]] = val;
414 }
415 }
416 result
417 }
418}
419
420fn mean_shift_single_seed_flat<
422 T: Float
423 + Display
424 + std::iter::Sum
425 + FromPrimitive
426 + Send
427 + Sync
428 + 'static
429 + scirs2_core::ndarray::ScalarOperand,
430>(
431 seed: ArrayView1<T>,
432 data: &ArrayView2<T>,
433 bandwidth: T,
434 max_iter: usize,
435) -> (Vec<T>, usize, usize) {
436 let stop_thresh = bandwidth * T::from(1e-3).unwrap_or(T::epsilon());
437 let mut my_mean = seed.to_owned();
438 let mut completed_iterations = 0;
439
440 let owned_data = data.to_owned();
441 let kdtree = match KDTree::<_, EuclideanDistance<T>>::new(&owned_data) {
442 Ok(tree) => tree,
443 Err(_) => return (seed.to_vec(), 0, 0),
444 };
445
446 loop {
447 let (indices, _distances) = match kdtree.query_radius(&my_mean.to_vec(), bandwidth) {
448 Ok((idx, distances)) => (idx, distances),
449 Err(_) => return (my_mean.to_vec(), 0, completed_iterations),
450 };
451
452 if indices.is_empty() {
453 break;
454 }
455 let my_old_mean = my_mean.clone();
456
457 my_mean.fill(T::zero());
459 let mut sum = Array1::zeros(my_mean.dim());
460 for &point_idx in &indices {
461 let row_clone = data.row(point_idx).to_owned();
462 for (s, v) in sum.iter_mut().zip(row_clone.iter()) {
463 *s = *s + *v;
464 }
465 }
466 my_mean = sum / T::from(indices.len()).unwrap_or(T::one());
467
468 let mut dist_squared = T::zero();
469 for (a, b) in my_mean.iter().zip(my_old_mean.iter()) {
470 dist_squared = dist_squared + (*a - *b) * (*a - *b);
471 }
472 let dist = dist_squared.sqrt();
473
474 if dist <= stop_thresh || completed_iterations == max_iter {
475 break;
476 }
477
478 completed_iterations += 1;
479 }
480
481 let (final_indices, _) = match kdtree.query_radius(&my_mean.to_vec(), bandwidth) {
482 Ok((idx, distances)) => (idx, distances),
483 Err(_) => return (my_mean.to_vec(), 0, completed_iterations),
484 };
485
486 (my_mean.to_vec(), final_indices.len(), completed_iterations)
487}
488
489fn mean_shift_single_seed_gaussian<
491 T: Float
492 + Display
493 + std::iter::Sum
494 + FromPrimitive
495 + Send
496 + Sync
497 + 'static
498 + scirs2_core::ndarray::ScalarOperand,
499>(
500 seed: ArrayView1<T>,
501 data: &ArrayView2<T>,
502 bandwidth: T,
503 max_iter: usize,
504) -> (Vec<T>, usize, usize) {
505 let stop_thresh = bandwidth * T::from(1e-3).unwrap_or(T::epsilon());
506 let mut my_mean = seed.to_owned();
507 let mut completed_iterations = 0;
508 let bw_sq = bandwidth * bandwidth;
509
510 let search_radius = bandwidth * T::from(3.0).unwrap_or(T::one() + T::one() + T::one());
512
513 let owned_data = data.to_owned();
514 let kdtree = match KDTree::<_, EuclideanDistance<T>>::new(&owned_data) {
515 Ok(tree) => tree,
516 Err(_) => return (seed.to_vec(), 0, 0),
517 };
518
519 loop {
520 let (indices, distances) = match kdtree.query_radius(&my_mean.to_vec(), search_radius) {
521 Ok((idx, distances)) => (idx, distances),
522 Err(_) => return (my_mean.to_vec(), 0, completed_iterations),
523 };
524
525 if indices.is_empty() {
526 break;
527 }
528 let my_old_mean = my_mean.clone();
529
530 let two = T::from(2.0).unwrap_or(T::one() + T::one());
532 let n_features = my_mean.dim();
533 let mut weighted_sum = Array1::zeros(n_features);
534 let mut weight_total = T::zero();
535
536 for (local_idx, &point_idx) in indices.iter().enumerate() {
537 let dist = distances[local_idx];
538 let dist_sq = dist * dist;
539 let weight = (-dist_sq / (two * bw_sq)).exp();
540
541 let row = data.row(point_idx);
542 for (ws, &v) in weighted_sum.iter_mut().zip(row.iter()) {
543 *ws = *ws + v * weight;
544 }
545 weight_total = weight_total + weight;
546 }
547
548 if weight_total > T::zero() {
549 my_mean = weighted_sum / weight_total;
550 }
551
552 let mut dist_squared = T::zero();
553 for (a, b) in my_mean.iter().zip(my_old_mean.iter()) {
554 dist_squared = dist_squared + (*a - *b) * (*a - *b);
555 }
556 let dist = dist_squared.sqrt();
557
558 if dist <= stop_thresh || completed_iterations == max_iter {
559 break;
560 }
561
562 completed_iterations += 1;
563 }
564
565 let (final_indices, _) = match kdtree.query_radius(&my_mean.to_vec(), bandwidth) {
566 Ok((idx, distances)) => (idx, distances),
567 Err(_) => return (my_mean.to_vec(), 0, completed_iterations),
568 };
569
570 (my_mean.to_vec(), final_indices.len(), completed_iterations)
571}
572
573pub fn mean_shift<
584 T: Float
585 + Display
586 + std::iter::Sum
587 + FromPrimitive
588 + Send
589 + Sync
590 + 'static
591 + scirs2_core::ndarray::ScalarOperand
592 + Debug,
593>(
594 data: &ArrayView2<T>,
595 options: MeanShiftOptions<T>,
596) -> Result<(Array2<T>, Array1<i32>), ClusteringError> {
597 let mut model = MeanShift::new(options);
598 let model = model.fit(data)?;
599 Ok((
600 model.cluster_centers().to_owned(),
601 model.labels().to_owned(),
602 ))
603}
604
605pub struct MeanShift<T: Float> {
607 options: MeanShiftOptions<T>,
608 cluster_centers_: Option<Array2<T>>,
609 labels_: Option<Array1<i32>>,
610 n_iter_: usize,
611 bandwidth_used_: Option<T>,
612}
613
614impl<
615 T: Float
616 + Display
617 + std::iter::Sum
618 + FromPrimitive
619 + Send
620 + Sync
621 + 'static
622 + scirs2_core::ndarray::ScalarOperand
623 + Debug,
624 > MeanShift<T>
625{
626 pub fn new(options: MeanShiftOptions<T>) -> Self {
628 Self {
629 options,
630 cluster_centers_: None,
631 labels_: None,
632 n_iter_: 0,
633 bandwidth_used_: None,
634 }
635 }
636
637 pub fn fit(&mut self, data: &ArrayView2<T>) -> Result<&mut Self, ClusteringError> {
639 let config = crate::input_validation::ValidationConfig::default();
640 crate::input_validation::validate_clustering_data(data.view(), &config)?;
641
642 let (n_samples, n_features) = data.dim();
643
644 let bandwidth = match self.options.bandwidth {
646 Some(bw) => check_positive(bw, "bandwidth")?,
647 None => match self.options.bandwidth_estimator {
648 BandwidthEstimator::Silverman => estimate_bandwidth_silverman(data)?,
649 BandwidthEstimator::Scott => estimate_bandwidth_scott(data)?,
650 BandwidthEstimator::KNNQuantile => {
651 estimate_bandwidth(data, Some(T::from(0.3).unwrap_or(T::one())), None, None)?
652 }
653 },
654 };
655 self.bandwidth_used_ = Some(bandwidth);
656
657 let seeds = match &self.options.seeds {
659 Some(s) => s.clone(),
660 None => {
661 if self.options.bin_seeding {
662 get_bin_seeds(data, bandwidth, self.options.min_bin_freq)
663 } else {
664 data.to_owned()
665 }
666 }
667 };
668
669 if seeds.is_empty() {
670 return Err(ClusteringError::ComputationError(
671 "No seeds provided and bin seeding produced no seeds".to_string(),
672 ));
673 }
674
675 let kernel = self.options.kernel;
677 let max_iter = self.options.max_iter;
678
679 let seed_results: Vec<_> = seeds
680 .axis_iter(Axis(0))
681 .map(|seed| match kernel {
682 KernelType::Flat => mean_shift_single_seed_flat(seed, data, bandwidth, max_iter),
683 KernelType::Gaussian => {
684 mean_shift_single_seed_gaussian(seed, data, bandwidth, max_iter)
685 }
686 })
687 .collect();
688
689 let mut center_intensity_dict: HashMap<FloatPoint<T>, usize> = HashMap::new();
691 for (center, size, iterations) in seed_results {
692 if size > 0 {
693 center_intensity_dict.insert(FloatPoint(center), size);
694 }
695 self.n_iter_ = self.n_iter_.max(iterations);
696 }
697
698 if center_intensity_dict.is_empty() {
699 return Err(ClusteringError::ComputationError(format!(
700 "No point was within bandwidth={} of any seed. \
701 Try a different seeding strategy or increase the bandwidth.",
702 bandwidth
703 )));
704 }
705
706 let mut sorted_by_intensity: Vec<_> = center_intensity_dict.into_iter().collect();
708 sorted_by_intensity.sort_by(|a, b| {
709 b.1.cmp(&a.1).then_with(|| {
710 a.0 .0
711 .iter()
712 .zip(b.0 .0.iter())
713 .find_map(|(a_val, b_val)| a_val.partial_cmp(b_val))
714 .unwrap_or(std::cmp::Ordering::Equal)
715 })
716 });
717
718 if !self.options.cluster_all {
719 let min_density_threshold = 2;
720 sorted_by_intensity.retain(|(_, intensity)| *intensity >= min_density_threshold);
721
722 if sorted_by_intensity.is_empty() {
723 return Err(ClusteringError::ComputationError(
724 "No clusters found with sufficient density.".to_string(),
725 ));
726 }
727 }
728
729 let mut sorted_centers = Array2::zeros((sorted_by_intensity.len(), n_features));
731 for (i, center_) in sorted_by_intensity.iter().enumerate() {
732 for (j, &val) in center_.0 .0.iter().enumerate() {
733 sorted_centers[[i, j]] = val;
734 }
735 }
736
737 let mut unique = vec![true; sorted_centers.nrows()];
739
740 let kdtree = KDTree::<_, EuclideanDistance<T>>::new(&sorted_centers).map_err(|e| {
741 ClusteringError::ComputationError(format!("Failed to build KDTree: {}", e))
742 })?;
743
744 let merge_threshold = bandwidth * T::from(0.1).unwrap_or(T::epsilon());
745
746 for i in 0..sorted_centers.nrows() {
747 if unique[i] {
748 let (indices_, _) = kdtree
749 .query_radius(&sorted_centers.row(i).to_vec(), merge_threshold)
750 .map_err(|e| {
751 ClusteringError::ComputationError(format!("Failed to query KDTree: {}", e))
752 })?;
753
754 for &idx in indices_.iter() {
755 if idx != i {
756 unique[idx] = false;
757 }
758 }
759 }
760 }
761
762 let unique_indices: Vec<_> = unique
763 .iter()
764 .enumerate()
765 .filter(|&(_, &is_unique)| is_unique)
766 .map(|(i_, _)| i_)
767 .collect();
768
769 let mut cluster_centers = Array2::zeros((unique_indices.len(), n_features));
770 for (i, &idx) in unique_indices.iter().enumerate() {
771 cluster_centers.row_mut(i).assign(&sorted_centers.row(idx));
772 }
773
774 let centers_kdtree =
776 KDTree::<_, EuclideanDistance<T>>::new(&cluster_centers).map_err(|e| {
777 ClusteringError::ComputationError(format!("Failed to build KDTree: {}", e))
778 })?;
779
780 let mut labels = Array1::zeros(n_samples);
781
782 let batch_size = 1000;
783 for i in (0..n_samples).step_by(batch_size) {
784 let end = (i + batch_size).min(n_samples);
785 let batch = data.slice(scirs2_core::ndarray::s![i..end, ..]);
786
787 for (row_idx, row) in batch.rows().into_iter().enumerate() {
788 let point_idx = i + row_idx;
789
790 let (indices, distances) = centers_kdtree.query(&row.to_vec(), 1).map_err(|e| {
791 ClusteringError::ComputationError(format!("Failed to query KDTree: {}", e))
792 })?;
793
794 if !indices.is_empty() {
795 let idx = indices[0];
796 let distance = T::from(distances[0]).unwrap_or(T::zero());
797
798 if self.options.cluster_all || (distance <= bandwidth) {
799 labels[point_idx] =
800 T::to_i32(&T::from(idx).unwrap_or(T::zero())).unwrap_or(0);
801 } else {
802 labels[point_idx] = -1;
803 }
804 } else {
805 labels[point_idx] = -1;
806 }
807 }
808 }
809
810 self.cluster_centers_ = Some(cluster_centers);
811 self.labels_ = Some(labels);
812
813 Ok(self)
814 }
815
816 pub fn cluster_centers(&self) -> &Array2<T> {
818 self.cluster_centers_
819 .as_ref()
820 .expect("Model has not been fitted yet")
821 }
822
823 pub fn labels(&self) -> &Array1<i32> {
825 self.labels_
826 .as_ref()
827 .expect("Model has not been fitted yet")
828 }
829
830 pub fn n_iter(&self) -> usize {
832 self.n_iter_
833 }
834
835 pub fn bandwidth_used(&self) -> Option<T> {
837 self.bandwidth_used_
838 }
839
840 pub fn predict(&self, data: &ArrayView2<T>) -> Result<Array1<i32>, ClusteringError> {
842 let centers = self.cluster_centers_.as_ref().ok_or_else(|| {
843 ClusteringError::InvalidState("Model has not been fitted yet".to_string())
844 })?;
845
846 checkarray_finite(data, "prediction data")?;
847
848 let n_samples = data.nrows();
849 let mut labels = Array1::zeros(n_samples);
850
851 let kdtree = KDTree::<_, EuclideanDistance<T>>::new(centers).map_err(|e| {
852 ClusteringError::ComputationError(format!("Failed to build KDTree: {}", e))
853 })?;
854
855 let batch_size = 1000;
856 for i in (0..n_samples).step_by(batch_size) {
857 let end = (i + batch_size).min(n_samples);
858 let batch = data.slice(scirs2_core::ndarray::s![i..end, ..]);
859
860 for (row_idx, row) in batch.rows().into_iter().enumerate() {
861 let (indices_, _distances) = kdtree.query(&row.to_vec(), 1).map_err(|e| {
862 ClusteringError::ComputationError(format!("Failed to query KDTree: {}", e))
863 })?;
864
865 if !indices_.is_empty() {
866 labels[i + row_idx] =
867 T::to_i32(&T::from(indices_[0]).unwrap_or(T::zero())).unwrap_or(0);
868 } else {
869 labels[i + row_idx] = -1;
870 }
871 }
872 }
873
874 Ok(labels)
875 }
876}
877
878#[cfg(test)]
879mod tests {
880 use super::*;
881 use scirs2_core::ndarray::{array, Array2};
882 use std::collections::HashSet;
883
884 fn make_test_data() -> Array2<f64> {
885 array![
886 [1.0, 1.0],
887 [2.0, 1.0],
888 [1.0, 0.0],
889 [4.0, 7.0],
890 [3.0, 5.0],
891 [3.0, 6.0]
892 ]
893 }
894
895 #[test]
896 fn test_estimate_bandwidth() {
897 let data = make_test_data();
898 let bandwidth = estimate_bandwidth(&data.view(), Some(0.4), None, None)
899 .expect("Bandwidth estimation should succeed");
900
901 assert!(
902 bandwidth > 0.0,
903 "Bandwidth should be positive, got: {}",
904 bandwidth
905 );
906 assert!(
907 bandwidth < 20.0,
908 "Bandwidth should be reasonable, got: {}",
909 bandwidth
910 );
911 }
912
913 #[test]
914 fn test_estimate_bandwidth_silverman() {
915 let data = make_test_data();
916 let bandwidth = estimate_bandwidth_silverman(&data.view())
917 .expect("Silverman estimation should succeed");
918
919 assert!(bandwidth > 0.0, "Silverman bandwidth should be positive");
920 assert!(bandwidth < 20.0, "Silverman bandwidth should be reasonable");
921 }
922
923 #[test]
924 fn test_estimate_bandwidth_scott() {
925 let data = make_test_data();
926 let bandwidth =
927 estimate_bandwidth_scott(&data.view()).expect("Scott estimation should succeed");
928
929 assert!(bandwidth > 0.0, "Scott bandwidth should be positive");
930 assert!(bandwidth < 20.0, "Scott bandwidth should be reasonable");
931 }
932
933 #[test]
934 fn test_estimate_bandwidth_small_sample() {
935 let data = array![[1.0, 1.0]];
936 let bandwidth = estimate_bandwidth(&data.view(), Some(0.3), None, None)
937 .expect("Should work for single sample");
938 assert!(bandwidth > 0.0);
939 assert_eq!(bandwidth, 1.0);
940 }
941
942 #[test]
943 fn test_get_bin_seeds() {
944 let data = array![
945 [1.0, 1.0],
946 [1.4, 1.4],
947 [1.8, 1.2],
948 [2.0, 1.0],
949 [2.1, 1.1],
950 [0.0, 0.0]
951 ];
952
953 let bin_seeds = get_bin_seeds(&data.view(), 1.0, 1);
954 assert_eq!(bin_seeds.nrows(), 3);
955
956 let bin_seeds = get_bin_seeds(&data.view(), 1.0, 2);
957 assert_eq!(bin_seeds.nrows(), 2);
958
959 let bin_seeds = get_bin_seeds(&data.view(), 0.01, 1);
960 assert_eq!(bin_seeds.nrows(), data.nrows());
961 }
962
963 #[test]
964 fn test_mean_shift_flat_kernel() {
965 let data = make_test_data();
966
967 let options = MeanShiftOptions {
968 bandwidth: Some(2.0),
969 kernel: KernelType::Flat,
970 ..Default::default()
971 };
972
973 let (centers, labels) =
974 mean_shift(&data.view(), options).expect("Mean shift with flat kernel should succeed");
975
976 assert!(centers.nrows() >= 1, "Should find at least 1 cluster");
977 assert!(centers.nrows() <= 3, "Should find at most 3 clusters");
978 assert!(
979 labels.iter().all(|&l| l >= 0),
980 "All labels should be non-negative"
981 );
982 }
983
984 #[test]
985 fn test_mean_shift_gaussian_kernel() {
986 let data = make_test_data();
987
988 let options = MeanShiftOptions {
989 bandwidth: Some(2.0),
990 kernel: KernelType::Gaussian,
991 ..Default::default()
992 };
993
994 let (centers, labels) = mean_shift(&data.view(), options)
995 .expect("Mean shift with Gaussian kernel should succeed");
996
997 assert!(centers.nrows() >= 1, "Should find at least 1 cluster");
998 assert!(
999 labels.iter().all(|&l| l >= 0),
1000 "All labels should be non-negative"
1001 );
1002 }
1003
1004 #[test]
1005 fn test_mean_shift_bin_seeding() {
1006 let data = make_test_data();
1007
1008 let options = MeanShiftOptions {
1009 bandwidth: Some(2.0),
1010 bin_seeding: true,
1011 ..Default::default()
1012 };
1013
1014 let (centers, labels) =
1015 mean_shift(&data.view(), options).expect("Mean shift with bin seeding should succeed");
1016
1017 assert!(centers.nrows() >= 1);
1018 assert!(centers.nrows() <= 3);
1019 assert!(labels.iter().all(|&l| l >= 0));
1020 }
1021
1022 #[test]
1023 fn test_mean_shift_no_cluster_all() {
1024 let data = array![
1025 [1.0, 1.0],
1026 [2.0, 1.0],
1027 [1.0, 0.0],
1028 [4.0, 7.0],
1029 [3.0, 5.0],
1030 [3.0, 6.0],
1031 [10.0, 10.0]
1032 ];
1033
1034 let options = MeanShiftOptions {
1035 bandwidth: Some(2.0),
1036 cluster_all: false,
1037 ..Default::default()
1038 };
1039
1040 let (_centers, labels) =
1041 mean_shift(&data.view(), options).expect("Mean shift should succeed");
1042
1043 assert!(labels.iter().any(|&l| l == -1));
1044 }
1045
1046 #[test]
1047 fn test_mean_shift_max_iter() {
1048 let data = make_test_data();
1049
1050 let options = MeanShiftOptions {
1051 bandwidth: Some(2.0),
1052 max_iter: 1,
1053 ..Default::default()
1054 };
1055
1056 let mut model = MeanShift::new(options);
1057 model.fit(&data.view()).expect("Should fit");
1058
1059 assert_eq!(model.n_iter(), 1);
1060 }
1061
1062 #[test]
1063 fn test_mean_shift_predict() {
1064 let data = make_test_data();
1065
1066 let options = MeanShiftOptions {
1067 bandwidth: Some(2.0),
1068 ..Default::default()
1069 };
1070
1071 let mut model = MeanShift::new(options);
1072 model.fit(&data.view()).expect("Should fit");
1073
1074 let predicted_labels = model.predict(&data.view()).expect("Predict should succeed");
1075 assert_eq!(predicted_labels, model.labels().clone());
1076 }
1077
1078 #[test]
1079 fn test_mean_shift_silverman_bandwidth() {
1080 let data = make_test_data();
1081
1082 let options = MeanShiftOptions {
1083 bandwidth: None,
1084 bandwidth_estimator: BandwidthEstimator::Silverman,
1085 ..Default::default()
1086 };
1087
1088 let mut model = MeanShift::new(options);
1089 model
1090 .fit(&data.view())
1091 .expect("Should fit with Silverman bandwidth");
1092
1093 assert!(model.bandwidth_used().is_some());
1094 assert!(
1095 model.bandwidth_used().unwrap_or(0.0) > 0.0,
1096 "Silverman bandwidth should be positive"
1097 );
1098 }
1099
1100 #[test]
1101 fn test_mean_shift_scott_bandwidth() {
1102 let data = make_test_data();
1103
1104 let options = MeanShiftOptions {
1105 bandwidth: None,
1106 bandwidth_estimator: BandwidthEstimator::Scott,
1107 ..Default::default()
1108 };
1109
1110 let mut model = MeanShift::new(options);
1111 model
1112 .fit(&data.view())
1113 .expect("Should fit with Scott bandwidth");
1114
1115 assert!(model.bandwidth_used().is_some());
1116 assert!(
1117 model.bandwidth_used().unwrap_or(0.0) > 0.0,
1118 "Scott bandwidth should be positive"
1119 );
1120 }
1121
1122 #[test]
1123 fn test_mean_shift_large_dataset() {
1124 let mut data = Array2::zeros((20, 2));
1125
1126 for i in 0..10 {
1127 data[[i, 0]] = 1.0 + 0.05 * (i as f64);
1128 data[[i, 1]] = 1.0 + 0.05 * (i as f64);
1129 }
1130
1131 for i in 10..20 {
1132 data[[i, 0]] = 8.0 + 0.05 * ((i - 10) as f64);
1133 data[[i, 1]] = 8.0 + 0.05 * ((i - 10) as f64);
1134 }
1135
1136 let options = MeanShiftOptions {
1137 bandwidth: Some(1.5),
1138 bin_seeding: true,
1139 ..Default::default()
1140 };
1141
1142 let (centers, labels) =
1143 mean_shift(&data.view(), options).expect("Should handle larger dataset");
1144
1145 assert!(centers.nrows() >= 1);
1146 assert!(centers.nrows() <= 3);
1147
1148 let unique_labels: HashSet<_> = labels.iter().cloned().collect();
1149 assert!(!unique_labels.is_empty());
1150 assert!(unique_labels.len() <= centers.nrows());
1151 }
1152}