1use std::f64::consts::TAU;
16
17use crate::error::{ClusteringError, Result};
18
/// Minimal 64-bit linear congruential generator.
///
/// Gives the clustering routines in this module a deterministic,
/// dependency-free random stream derived from a `u64` seed. The low 11 bits
/// of the state are discarded when producing floats.
struct Lcg {
    /// Internal state, advanced by one LCG step per `next_f64` call.
    state: u64,
}

impl Lcg {
    /// Multiplier of the recurrence; doubles as the replacement for a zero seed.
    const MULTIPLIER: u64 = 6364136223846793005;
    /// Additive constant of the recurrence.
    const INCREMENT: u64 = 1442695040888963407;

    /// Seeds the generator; `seed == 0` is mapped to a fixed non-zero state.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => Self::MULTIPLIER,
            nonzero => nonzero,
        };
        Lcg { state }
    }

    /// Steps the state and returns a uniform value in `[0, 1)` built from the
    /// top 53 bits of the new state.
    fn next_f64(&mut self) -> f64 {
        let stepped = self.state.wrapping_mul(Self::MULTIPLIER);
        self.state = stepped.wrapping_add(Self::INCREMENT);
        let top_53 = (self.state >> 11) as f64;
        let scale = (1u64 << 53) as f64;
        top_53 / scale
    }

    /// Uniform index in `[low, high)`; returns `low` unchanged (without
    /// consuming a draw) when the range is empty.
    fn next_range_usize(&mut self, low: usize, high: usize) -> usize {
        if high <= low {
            return low;
        }
        let width = (high - low) as f64;
        let offset = (self.next_f64() * width) as usize;
        low + offset
    }

    /// Standard-normal draw via Box–Muller (cosine branch only; two uniforms
    /// are consumed per call). `u1` is floored at `1e-15` so `ln` stays finite.
    fn next_normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(1e-15);
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = TAU * u2;
        radius * angle.cos()
    }
}
61
62#[derive(Clone, Debug)]
68pub enum KernelType {
69 Linear,
71 Polynomial { degree: u32, coef0: f64, gamma: f64 },
73 Rbf { gamma: f64 },
75 Sigmoid { coef0: f64, gamma: f64 },
77}
78
79impl KernelType {
80 pub fn compute(&self, x: &[f64], y: &[f64]) -> f64 {
82 debug_assert_eq!(x.len(), y.len(), "kernel vectors must have same dimension");
83 match self {
84 KernelType::Linear => dot(x, y),
85 KernelType::Polynomial {
86 degree,
87 coef0,
88 gamma,
89 } => (gamma * dot(x, y) + coef0).powi(*degree as i32),
90 KernelType::Rbf { gamma } => {
91 let sq = sq_dist(x, y);
92 (-gamma * sq).exp()
93 }
94 KernelType::Sigmoid { coef0, gamma } => (gamma * dot(x, y) + coef0).tanh(),
95 }
96 }
97}
98
/// Inner product of two slices; surplus elements of the longer slice are ignored.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
102
/// Squared Euclidean distance between two slices (zipped, so the shorter length wins).
fn sq_dist(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(&p, &q)| {
            let delta = p - q;
            delta * delta
        })
        .sum()
}
109
110fn build_kernel_matrix(data: &[Vec<f64>], kernel: &KernelType) -> Vec<Vec<f64>> {
116 let n = data.len();
117 let mut k = vec![vec![0.0f64; n]; n];
118 for i in 0..n {
119 k[i][i] = kernel.compute(&data[i], &data[i]);
120 for j in (i + 1)..n {
121 let v = kernel.compute(&data[i], &data[j]);
122 k[i][j] = v;
123 k[j][i] = v;
124 }
125 }
126 k
127}
128
/// Kernel k-means objective: the sum over all points of the squared
/// feature-space distance to the mean of the point's cluster,
///
/// `sum_i [ K_ii - 2/|C| * sum_{j in C} K_ij + 1/|C|^2 * sum_{j,l in C} K_jl ]`.
///
/// Every `labels[i]` must be `< n_clusters` (otherwise this panics on the
/// index). Clusters without members contribute nothing.
///
/// The within-cluster term depends only on the cluster, so it is computed once
/// per cluster (O(n^2) overall) rather than once per point (O(n * |C|^2)).
fn kernel_kmeans_objective(k_mat: &[Vec<f64>], labels: &[usize], n_clusters: usize) -> f64 {
    let mut members: Vec<Vec<usize>> = vec![Vec::new(); n_clusters];
    for (i, &l) in labels.iter().enumerate() {
        members[l].push(i);
    }

    // inner[l] = sum_{j,k in C_l} K_jk / |C_l|^2; empty clusters are never
    // referenced below (no point carries their label), so 0.0 is a placeholder.
    let inner: Vec<f64> = members
        .iter()
        .map(|cl| {
            if cl.is_empty() {
                return 0.0;
            }
            let sz = cl.len() as f64;
            let s: f64 = cl
                .iter()
                .flat_map(|&j| cl.iter().map(move |&kk| k_mat[j][kk]))
                .sum();
            s / (sz * sz)
        })
        .collect();

    let mut total = 0.0f64;
    for (i, &l) in labels.iter().enumerate() {
        // Point i belongs to cluster l, so `cl` is never empty here.
        let cl = &members[l];
        let sz = cl.len() as f64;
        let cross: f64 = cl.iter().map(|&j| k_mat[i][j]).sum();
        total += k_mat[i][i] - 2.0 * cross / sz + inner[l];
    }
    total
}
160
/// One assignment step of kernel k-means.
///
/// Each point is moved to the non-empty cluster (under the current `labels`)
/// whose feature-space mean is nearest, where
/// `d(i, C) = K_ii - 2/|C| * sum_{j in C} K_ij + 1/|C|^2 * sum_{j,l in C} K_jl`.
/// Ties keep the lowest cluster index. If no distance compares below the
/// running best (e.g. all NaN), the point gets label 0.
fn kernel_kmeans_assign(k_mat: &[Vec<f64>], labels: &[usize], n_clusters: usize) -> Vec<usize> {
    let mut clusters: Vec<Vec<usize>> = vec![Vec::new(); n_clusters];
    labels
        .iter()
        .enumerate()
        .for_each(|(idx, &lab)| clusters[lab].push(idx));

    // Per-cluster 1/|C|^2 * sum K_jl; empty clusters are skipped in the scan below.
    let within: Vec<f64> = clusters
        .iter()
        .map(|cl| {
            if cl.is_empty() {
                return f64::INFINITY;
            }
            let size = cl.len() as f64;
            let total: f64 = cl
                .iter()
                .flat_map(|&a| cl.iter().map(move |&b| k_mat[a][b]))
                .sum();
            total / (size * size)
        })
        .collect();

    (0..labels.len())
        .map(|i| {
            let mut best = (0usize, f64::INFINITY);
            for (l, cl) in clusters.iter().enumerate() {
                if cl.is_empty() {
                    continue;
                }
                let size = cl.len() as f64;
                let cross: f64 = cl.iter().map(|&j| k_mat[i][j]).sum();
                let dist = k_mat[i][i] - 2.0 * cross / size + within[l];
                if dist < best.1 {
                    best = (l, dist);
                }
            }
            best.0
        })
        .collect()
}
207
208pub fn kernel_kmeans(
227 data: &[Vec<f64>],
228 n_clusters: usize,
229 kernel: KernelType,
230 max_iter: usize,
231 n_init: usize,
232 seed: u64,
233) -> Result<(Vec<usize>, f64)> {
234 if data.is_empty() {
235 return Err(ClusteringError::InvalidInput(
236 "data must not be empty".into(),
237 ));
238 }
239 if n_clusters == 0 {
240 return Err(ClusteringError::InvalidInput(
241 "n_clusters must be >= 1".into(),
242 ));
243 }
244 if n_clusters > data.len() {
245 return Err(ClusteringError::InvalidInput(format!(
246 "n_clusters ({}) > n_samples ({})",
247 n_clusters,
248 data.len()
249 )));
250 }
251 let n = data.len();
252 let k_mat = build_kernel_matrix(data, &kernel);
253 let _rng = Lcg::new(seed); let mut best_labels = vec![0usize; n];
256 let mut best_obj = f64::INFINITY;
257
258 for run in 0..n_init.max(1) {
259 let run_seed = seed.wrapping_add(run as u64).wrapping_add(1);
262 let mut run_rng = Lcg::new(run_seed);
263
264 let mut labels = vec![0usize; n];
266 let mut idx: Vec<usize> = (0..n).collect();
268 for i in 0..n_clusters {
270 let j = run_rng.next_range_usize(i, n);
271 idx.swap(i, j);
272 labels[idx[i]] = i;
273 }
274 for i in n_clusters..n {
276 labels[i] = run_rng.next_range_usize(0, n_clusters);
277 }
278
279 for _ in 0..max_iter {
281 let new_labels = kernel_kmeans_assign(&k_mat, &labels, n_clusters);
282 let mut counts = vec![0usize; n_clusters];
284 for &l in &new_labels {
285 counts[l] += 1;
286 }
287 let empty = counts.iter().any(|&c| c == 0);
288 if empty {
289 break;
291 }
292 if new_labels == labels {
293 break;
294 }
295 labels = new_labels;
296 }
297
298 let obj = kernel_kmeans_objective(&k_mat, &labels, n_clusters);
299 if obj < best_obj {
300 best_obj = obj;
301 best_labels = labels;
302 }
303 }
305
306 Ok((best_labels, best_obj))
307}
308
309pub fn trimmed_kmeans(
332 data: &[Vec<f64>],
333 n_clusters: usize,
334 trim_ratio: f64,
335 max_iter: usize,
336 seed: u64,
337) -> Result<(Vec<Option<usize>>, Vec<Vec<f64>>)> {
338 if data.is_empty() {
339 return Err(ClusteringError::InvalidInput(
340 "data must not be empty".into(),
341 ));
342 }
343 if n_clusters == 0 {
344 return Err(ClusteringError::InvalidInput(
345 "n_clusters must be >= 1".into(),
346 ));
347 }
348 if !(0.0..0.5).contains(&trim_ratio) {
349 return Err(ClusteringError::InvalidInput(
350 "trim_ratio must be in [0, 0.5)".into(),
351 ));
352 }
353 let n = data.len();
354 let d = data[0].len();
355 if n_clusters > n {
356 return Err(ClusteringError::InvalidInput(format!(
357 "n_clusters ({}) > n_samples ({})",
358 n_clusters, n
359 )));
360 }
361
362 let n_trim = (n as f64 * trim_ratio).floor() as usize;
363 let n_active = n - n_trim;
364 if n_active < n_clusters {
365 return Err(ClusteringError::InvalidInput(
366 "After trimming, too few points remain for the requested n_clusters".into(),
367 ));
368 }
369
370 let mut rng = Lcg::new(seed);
371
372 let mut centroids = kmeans_plus_plus_init(data, n_clusters, &mut rng);
374
375 let mut labels = vec![None::<usize>; n];
376
377 for _iter in 0..max_iter {
378 let mut dists: Vec<(usize, f64)> = (0..n)
380 .map(|i| {
381 let (cl, dist) = nearest_centroid(&data[i], ¢roids);
382 (cl, dist)
383 })
384 .collect();
385
386 let mut order: Vec<usize> = (0..n).collect();
388 order.sort_by(|&a, &b| {
389 dists[b]
390 .1
391 .partial_cmp(&dists[a].1)
392 .unwrap_or(std::cmp::Ordering::Equal)
393 });
394 let trimmed_set: std::collections::HashSet<usize> =
395 order[..n_trim].iter().cloned().collect();
396
397 for i in 0..n {
399 if trimmed_set.contains(&i) {
400 labels[i] = None;
401 } else {
402 labels[i] = Some(dists[i].0);
403 }
404 }
405
406 let mut new_centroids = vec![vec![0.0f64; d]; n_clusters];
408 let mut counts = vec![0usize; n_clusters];
409 for (i, lbl) in labels.iter().enumerate() {
410 if let Some(l) = lbl {
411 for (feat, &v) in new_centroids[*l].iter_mut().zip(data[i].iter()) {
412 *feat += v;
413 }
414 counts[*l] += 1;
415 }
416 }
417 let mut changed = false;
418 for l in 0..n_clusters {
419 if counts[l] > 0 {
420 let old = ¢roids[l];
421 let new_c: Vec<f64> = new_centroids[l]
422 .iter()
423 .map(|&s| s / counts[l] as f64)
424 .collect();
425 let diff: f64 = old
426 .iter()
427 .zip(new_c.iter())
428 .map(|(a, b)| (a - b).powi(2))
429 .sum::<f64>()
430 .sqrt();
431 if diff > 1e-10 {
432 changed = true;
433 }
434 centroids[l] = new_c;
435 }
436 }
437 if !changed {
438 break;
439 }
440 }
441
442 Ok((labels, centroids))
443}
444
/// Returns the index of the centroid closest to `point` and its squared
/// Euclidean distance.
///
/// Ties resolve to the lowest index. With no centroids, or if every distance
/// is NaN, the result is `(0, f64::INFINITY)`.
fn nearest_centroid(point: &[f64], centroids: &[Vec<f64>]) -> (usize, f64) {
    centroids
        .iter()
        .enumerate()
        .fold((0, f64::INFINITY), |best, (idx, centroid)| {
            let dist: f64 = point
                .iter()
                .zip(centroid.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum();
            if dist < best.1 {
                (idx, dist)
            } else {
                best
            }
        })
}
461
462fn kmeans_plus_plus_init(data: &[Vec<f64>], k: usize, rng: &mut Lcg) -> Vec<Vec<f64>> {
463 let n = data.len();
464 let first = rng.next_range_usize(0, n);
465 let mut centroids = vec![data[first].clone()];
466 for _ in 1..k {
467 let dists: Vec<f64> = data
468 .iter()
469 .map(|x| {
470 centroids
471 .iter()
472 .map(|c| sq_dist(x, c))
473 .fold(f64::INFINITY, f64::min)
474 })
475 .collect();
476 let total: f64 = dists.iter().sum();
477 let target = rng.next_f64() * total;
478 let mut cumsum = 0.0;
479 let mut chosen = n - 1;
480 for (i, &d) in dists.iter().enumerate() {
481 cumsum += d;
482 if cumsum >= target {
483 chosen = i;
484 break;
485 }
486 }
487 centroids.push(data[chosen].clone());
488 }
489 centroids
490}
491
492#[derive(Debug, Clone)]
505pub struct DpMixture {
506 pub alpha: f64,
508 pub n_components: usize,
510 pub weights: Vec<f64>,
512 pub means: Vec<Vec<f64>>,
514 pub concentrations: Vec<f64>,
516}
517
518impl DpMixture {
519 pub fn new(alpha: f64) -> Self {
521 Self {
522 alpha,
523 n_components: 0,
524 weights: Vec::new(),
525 means: Vec::new(),
526 concentrations: Vec::new(),
527 }
528 }
529
530 pub fn fit(&mut self, data: &[Vec<f64>], max_iter: usize, seed: u64) -> Vec<usize> {
535 if data.is_empty() {
536 self.n_components = 0;
537 return Vec::new();
538 }
539 let n = data.len();
540 let d = data[0].len();
541 let mut rng = Lcg::new(seed);
542
543 let prior_mean = vec![0.0f64; d];
545 let prior_kappa = 1.0f64; let lambda = 1.0f64; let mut assignments: Vec<usize> = (0..n).collect();
550 let mut component_members: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
552
553 for _iter in 0..max_iter {
554 for i in 0..n {
555 let xi = &data[i];
556 let current_k = assignments[i];
557
558 component_members[current_k].retain(|&m| m != i);
560
561 let alive: Vec<usize> = (0..component_members.len())
563 .filter(|&k| !component_members[k].is_empty())
564 .collect();
565 let new_members: Vec<Vec<usize>> = alive
567 .iter()
568 .map(|&k| component_members[k].clone())
569 .collect();
570 for j in 0..n {
572 if j == i {
573 continue;
574 }
575 let old_k = assignments[j];
576 if let Some(pos) = alive.iter().position(|&k| k == old_k) {
577 assignments[j] = pos;
578 }
579 }
580 component_members = new_members;
581
582 let k_live = component_members.len();
583
584 let mut log_probs: Vec<f64> = Vec::with_capacity(k_live + 1);
586 for k in 0..k_live {
587 let members = &component_members[k];
588 let n_k = members.len() as f64;
589 let kappa_n = prior_kappa + n_k;
591 let mu_n: Vec<f64> = {
592 let mut s = prior_mean.clone();
593 for &m in members.iter() {
594 for (f, &v) in s.iter_mut().zip(data[m].iter()) {
595 *f += v;
596 }
597 }
598 s.iter().map(|&v| v / (prior_kappa + n_k)).collect()
599 };
600 let pred_var = (kappa_n + 1.0) / (kappa_n * lambda);
601 let log_lik: f64 = xi
603 .iter()
604 .zip(mu_n.iter())
605 .map(|(&xf, &mf)| {
606 let z = (xf - mf).powi(2);
607 -0.5 * (z / pred_var + (TAU * pred_var).ln())
608 })
609 .sum();
610 let log_prior = (n_k / (n as f64 - 1.0 + self.alpha)).ln();
611 log_probs.push(log_prior + log_lik);
612 }
613
614 let pred_var_new = (prior_kappa + 1.0) / (prior_kappa * lambda);
616 let log_lik_new: f64 = xi
617 .iter()
618 .zip(prior_mean.iter())
619 .map(|(&xf, &mf)| {
620 let z = (xf - mf).powi(2);
621 -0.5 * (z / pred_var_new + (TAU * pred_var_new).ln())
622 })
623 .sum();
624 let log_prior_new = (self.alpha / (n as f64 - 1.0 + self.alpha)).ln();
625 log_probs.push(log_prior_new + log_lik_new);
626
627 let max_lp = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
629 let probs: Vec<f64> = log_probs.iter().map(|&lp| (lp - max_lp).exp()).collect();
630 let total: f64 = probs.iter().sum();
631 let u = rng.next_f64() * total;
632 let mut cumsum = 0.0;
633 let mut chosen_k = probs.len() - 1;
634 for (idx, &p) in probs.iter().enumerate() {
635 cumsum += p;
636 if cumsum >= u {
637 chosen_k = idx;
638 break;
639 }
640 }
641
642 if chosen_k == k_live {
643 component_members.push(vec![i]);
645 assignments[i] = k_live;
646 } else {
647 component_members[chosen_k].push(i);
648 assignments[i] = chosen_k;
649 }
650 }
651 }
652
653 let k_final = component_members.len();
655 self.n_components = k_final;
656 self.weights = component_members
657 .iter()
658 .map(|m| m.len() as f64 / n as f64)
659 .collect();
660 self.means = component_members
661 .iter()
662 .map(|members| {
663 let mut mu = vec![0.0f64; d];
664 for &m in members.iter() {
665 for (f, &v) in mu.iter_mut().zip(data[m].iter()) {
666 *f += v;
667 }
668 }
669 mu.iter().map(|&v| v / members.len() as f64).collect()
670 })
671 .collect();
672 self.concentrations = vec![lambda; k_final];
673
674 assignments
675 }
676
677 pub fn n_clusters(&self) -> usize {
679 self.n_components
680 }
681}
682
683pub fn fuzzy_cmeans(
707 data: &[Vec<f64>],
708 n_clusters: usize,
709 fuzziness: f64,
710 max_iter: usize,
711 tol: f64,
712 seed: u64,
713) -> Result<(Vec<Vec<f64>>, Vec<Vec<f64>>)> {
714 if data.is_empty() {
715 return Err(ClusteringError::InvalidInput(
716 "data must not be empty".into(),
717 ));
718 }
719 if n_clusters == 0 {
720 return Err(ClusteringError::InvalidInput(
721 "n_clusters must be >= 1".into(),
722 ));
723 }
724 if fuzziness <= 1.0 {
725 return Err(ClusteringError::InvalidInput(
726 "fuzziness (m) must be > 1.0".into(),
727 ));
728 }
729 if n_clusters > data.len() {
730 return Err(ClusteringError::InvalidInput(format!(
731 "n_clusters ({}) > n_samples ({})",
732 n_clusters,
733 data.len()
734 )));
735 }
736
737 let n = data.len();
738 let d = data[0].len();
739 let m = fuzziness;
740 let exp = 2.0 / (m - 1.0);
741
742 let mut rng = Lcg::new(seed);
744 let mut u: Vec<Vec<f64>> = (0..n)
745 .map(|_| {
746 let raw: Vec<f64> = (0..n_clusters).map(|_| rng.next_f64() + 1e-12).collect();
747 let s: f64 = raw.iter().sum();
748 raw.iter().map(|&v| v / s).collect()
749 })
750 .collect();
751
752 let mut centroids = compute_fuzzy_centroids(data, &u, n_clusters, m, d);
753
754 for _iter in 0..max_iter {
755 let mut new_u = vec![vec![0.0f64; n_clusters]; n];
757 for i in 0..n {
758 let dists: Vec<f64> = (0..n_clusters)
759 .map(|c| sq_dist(&data[i], ¢roids[c]).max(1e-30))
760 .collect();
761 let exact: Vec<usize> = dists
763 .iter()
764 .enumerate()
765 .filter(|(_, &d)| d < 1e-30)
766 .map(|(c, _)| c)
767 .collect();
768 if !exact.is_empty() {
769 let share = 1.0 / exact.len() as f64;
770 for &c in &exact {
771 new_u[i][c] = share;
772 }
773 } else {
774 for c in 0..n_clusters {
775 let ratio_sum: f64 = (0..n_clusters)
776 .map(|j| (dists[c] / dists[j]).powf(exp))
777 .sum();
778 new_u[i][c] = 1.0 / ratio_sum;
779 }
780 }
781 }
782
783 let new_centroids = compute_fuzzy_centroids(data, &new_u, n_clusters, m, d);
785
786 let max_change: f64 = centroids
788 .iter()
789 .zip(new_centroids.iter())
790 .map(|(c_old, c_new)| {
791 c_old
792 .iter()
793 .zip(c_new.iter())
794 .map(|(a, b)| (a - b).abs())
795 .fold(0.0f64, f64::max)
796 })
797 .fold(0.0f64, f64::max);
798
799 u = new_u;
800 centroids = new_centroids;
801
802 if max_change < tol {
803 break;
804 }
805 }
806
807 Ok((centroids, u))
808}
809
/// Weighted centroids for fuzzy c-means: `c_j = sum_i u_ij^m x_i / sum_i u_ij^m`.
///
/// Reads `u[i][c]` for every `i < data.len()` and `c < n_clusters`, and
/// produces centroids of length `d`. A cluster whose total weight has
/// magnitude below `1e-30` gets an all-zero centroid.
fn compute_fuzzy_centroids(
    data: &[Vec<f64>],
    u: &[Vec<f64>],
    n_clusters: usize,
    m: f64,
    d: usize,
) -> Vec<Vec<f64>> {
    let mut centroids = Vec::with_capacity(n_clusters);
    for c in 0..n_clusters {
        let mut weighted_sum = vec![0.0f64; d];
        let mut weight_total = 0.0f64;
        for (i, point) in data.iter().enumerate() {
            let w = u[i][c].powf(m);
            weight_total += w;
            weighted_sum
                .iter_mut()
                .zip(point.iter())
                .for_each(|(acc, &v)| *acc += w * v);
        }
        let centroid: Vec<f64> = if weight_total.abs() < 1e-30 {
            vec![0.0f64; d]
        } else {
            weighted_sum.into_iter().map(|s| s / weight_total).collect()
        };
        centroids.push(centroid);
    }
    centroids
}
836
#[cfg(test)]
mod tests {
    use super::*;

    /// Two well-separated diagonal blobs of 20 points each, one starting at the
    /// origin and one at `(10, 10)`, stepping by `0.1` per coordinate.
    fn two_cluster_data() -> Vec<Vec<f64>> {
        fn blob(offset: f64) -> impl Iterator<Item = Vec<f64>> {
            (0..20).map(move |i| {
                let t = offset + i as f64 * 0.1;
                vec![t, t]
            })
        }
        blob(0.0).chain(blob(10.0)).collect()
    }

    /// RBF kernel k-means must give each blob its own, pure cluster.
    #[test]
    fn test_kernel_kmeans_rbf_two_clusters() {
        let data = two_cluster_data();
        let (labels, inertia) =
            kernel_kmeans(&data, 2, KernelType::Rbf { gamma: 0.5 }, 50, 3, 42)
                .expect("kernel_kmeans should succeed");
        assert_eq!(labels.len(), 40);
        assert!(inertia.is_finite());
        let (first_blob, second_blob) = labels.split_at(20);
        let (l0, l20) = (first_blob[0], second_blob[0]);
        assert_ne!(l0, l20, "blobs should be in different clusters");
        assert!(first_blob.iter().all(|&l| l == l0));
        assert!(second_blob.iter().all(|&l| l == l20));
    }

    #[test]
    fn test_kernel_kmeans_linear() {
        let data = two_cluster_data();
        let result = kernel_kmeans(&data, 2, KernelType::Linear, 20, 2, 7);
        let (labels, _) = result.expect("kernel_kmeans linear should succeed");
        assert_eq!(labels.len(), 40);
    }

    #[test]
    fn test_kernel_kmeans_polynomial() {
        let data = two_cluster_data();
        let kernel = KernelType::Polynomial {
            degree: 2,
            coef0: 1.0,
            gamma: 0.1,
        };
        let (labels, _) = kernel_kmeans(&data, 2, kernel, 20, 2, 99)
            .expect("kernel_kmeans poly should succeed");
        assert_eq!(labels.len(), 40);
    }

    #[test]
    fn test_kernel_kmeans_invalid_inputs() {
        let data = two_cluster_data();
        // Empty data, zero clusters, and more clusters than samples.
        assert!(kernel_kmeans(&[], 2, KernelType::Linear, 10, 1, 0).is_err());
        assert!(kernel_kmeans(&data, 0, KernelType::Linear, 10, 1, 0).is_err());
        assert!(kernel_kmeans(&data, 100, KernelType::Linear, 10, 1, 0).is_err());
    }

    /// Two far-away outliers are added; trimming should discard at least one.
    #[test]
    fn test_trimmed_kmeans_basic() {
        let mut data = two_cluster_data();
        data.push(vec![100.0, 100.0]);
        data.push(vec![-100.0, -100.0]);

        let (labels, centroids) =
            trimmed_kmeans(&data, 2, 0.05, 100, 42).expect("trimmed_kmeans should succeed");
        assert_eq!(labels.len(), data.len());
        assert_eq!(centroids.len(), 2);
        let trimmed_count = labels.iter().filter(|label| label.is_none()).count();
        assert!(trimmed_count >= 1, "at least one outlier should be trimmed");
    }

    #[test]
    fn test_trimmed_kmeans_no_trim() {
        let data = two_cluster_data();
        let (labels, centroids) = trimmed_kmeans(&data, 2, 0.0, 50, 0)
            .expect("trimmed_kmeans with trim=0 should succeed");
        assert_eq!(labels.len(), 40);
        assert_eq!(centroids.len(), 2);
        assert!(labels.iter().all(Option::is_some));
    }

    #[test]
    fn test_trimmed_kmeans_invalid() {
        let data = two_cluster_data();
        assert!(trimmed_kmeans(&[], 2, 0.1, 10, 0).is_err());
        assert!(trimmed_kmeans(&data, 0, 0.1, 10, 0).is_err());
        assert!(trimmed_kmeans(&data, 2, 0.6, 10, 0).is_err());
    }

    #[test]
    fn test_dp_mixture_finds_clusters() {
        let data = two_cluster_data();
        let mut dpm = DpMixture::new(1.0);
        let labels = dpm.fit(&data, 30, 42);
        assert_eq!(labels.len(), 40);
        assert!(dpm.n_clusters() >= 1);
        assert!(!dpm.means.is_empty());
        assert!(!dpm.weights.is_empty());
        let weight_sum: f64 = dpm.weights.iter().sum();
        assert!((weight_sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_dp_mixture_empty() {
        let mut dpm = DpMixture::new(1.0);
        let labels = dpm.fit(&[], 10, 0);
        assert!(labels.is_empty());
        assert_eq!(dpm.n_clusters(), 0);
    }

    #[test]
    fn test_fuzzy_cmeans_basic() {
        let data = two_cluster_data();
        let (centroids, membership) =
            fuzzy_cmeans(&data, 2, 2.0, 100, 1e-6, 42).expect("fuzzy_cmeans should succeed");
        assert_eq!(centroids.len(), 2);
        assert_eq!(membership.len(), 40);
        assert_eq!(membership[0].len(), 2);
        // Every membership row is a distribution over the clusters.
        for s in membership.iter().map(|row| row.iter().sum::<f64>()) {
            assert!(
                (s - 1.0).abs() < 1e-8,
                "membership row must sum to 1, got {}",
                s
            );
        }
    }

    #[test]
    fn test_fuzzy_cmeans_high_fuzziness() {
        let data = two_cluster_data();
        let (_, membership) = fuzzy_cmeans(&data, 3, 3.5, 50, 1e-5, 99)
            .expect("fuzzy_cmeans high fuzz should succeed");
        for s in membership.iter().map(|row| row.iter().sum::<f64>()) {
            assert!((s - 1.0).abs() < 1e-7);
        }
    }

    #[test]
    fn test_fuzzy_cmeans_invalid() {
        let data = two_cluster_data();
        assert!(fuzzy_cmeans(&[], 2, 2.0, 10, 1e-6, 0).is_err());
        // (n_clusters, fuzziness) pairs that must be rejected for non-empty data.
        let bad: [(usize, f64); 4] = [(0, 2.0), (2, 1.0), (2, 0.5), (100, 2.0)];
        for &(k, m) in bad.iter() {
            assert!(fuzzy_cmeans(&data, k, m, 10, 1e-6, 0).is_err());
        }
    }
}