1use std::f64::consts::PI;
29
30use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
31
32use crate::error::{ClusteringError, Result};
33
/// Digamma function ψ(x) = d/dx ln Γ(x) for positive arguments.
///
/// The recurrence ψ(x) = ψ(x + 1) − 1/x lifts the argument to at least 6. The asymptotic
/// expansion ψ(z) ≈ ln z − 1/(2z) − 1/(12z²) + 1/(120z⁴) − 1/(252z⁶) is then applied.
/// Returns negative infinity for `x <= 0.0`; a NaN argument yields NaN.
fn digamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }

    // Each unit step upward contributes −1/z to the final value.
    let mut z = x;
    let mut shift_correction = 0.0;
    while z < 6.0 {
        shift_correction -= 1.0 / z;
        z += 1.0;
    }

    let leading = z.ln() - 0.5 / z;
    let inv_z2 = 1.0 / (z * z);
    let series_tail = inv_z2 * (1.0 / 12.0 - inv_z2 * (1.0 / 120.0 - inv_z2 / 252.0));
    (shift_correction + leading) - series_tail
}
56
/// Numerically stable `ln(Σ exp(row[i]))`.
///
/// The maximum is factored out before exponentiating so large magnitudes do not overflow.
///
/// Returns:
/// - negative infinity for an empty slice or when every entry is `-inf` (the sum is zero);
/// - positive infinity when any entry is `+inf`.
///
/// NaN entries are ignored when locating the maximum. When that maximum is finite, a NaN
/// entry makes the result NaN.
fn logsumexp_row(row: &[f64]) -> f64 {
    let max = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if max == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    if max == f64::INFINITY {
        // exp(+inf) dominates the sum; subtracting it would produce inf − inf = NaN.
        return f64::INFINITY;
    }
    let sum: f64 = row.iter().map(|&v| (v - max).exp()).sum();
    max + sum.ln()
}
66
67fn cholesky(a: &Array2<f64>) -> Result<Array2<f64>> {
70 let n = a.shape()[0];
71 let mut l = Array2::<f64>::zeros((n, n));
72 for i in 0..n {
73 for j in 0..=i {
74 let mut s = a[[i, j]];
75 for k in 0..j {
76 s -= l[[i, k]] * l[[j, k]];
77 }
78 if i == j {
79 if s <= 0.0 {
80 s = 1e-12;
81 }
82 l[[i, j]] = s.sqrt();
83 } else if l[[j, j]].abs() < 1e-15 {
84 l[[i, j]] = 0.0;
85 } else {
86 l[[i, j]] = s / l[[j, j]];
87 }
88 }
89 }
90 Ok(l)
91}
92
93fn log_det_pd(a: &Array2<f64>) -> Result<f64> {
95 let l = cholesky(a)?;
96 let n = l.shape()[0];
97 let mut log_det = 0.0;
98 for i in 0..n {
99 log_det += 2.0 * l[[i, i]].ln();
100 }
101 Ok(log_det)
102}
103
104fn cholesky_solve(l: &Array2<f64>, b: ArrayView1<f64>) -> Array1<f64> {
106 let n = l.shape()[0];
107 let mut y = Array1::<f64>::zeros(n);
108 for i in 0..n {
110 let mut s = b[i];
111 for k in 0..i {
112 s -= l[[i, k]] * y[k];
113 }
114 y[i] = if l[[i, i]].abs() < 1e-15 {
115 0.0
116 } else {
117 s / l[[i, i]]
118 };
119 }
120 let mut x = Array1::<f64>::zeros(n);
122 for i in (0..n).rev() {
123 let mut s = y[i];
124 for k in (i + 1)..n {
125 s -= l[[k, i]] * x[k];
126 }
127 x[i] = if l[[i, i]].abs() < 1e-15 {
128 0.0
129 } else {
130 s / l[[i, i]]
131 };
132 }
133 x
134}
135
136fn log_mvn(x: ArrayView1<f64>, mu: ArrayView1<f64>, l: &Array2<f64>) -> f64 {
139 let d = x.len() as f64;
140 let diff: Array1<f64> = x.iter().zip(mu.iter()).map(|(&xi, &mi)| xi - mi).collect();
141 let z = cholesky_solve(l, diff.view());
142 let maha: f64 = z.iter().map(|&v| v * v).sum();
143 let log_det_l: f64 = (0..l.shape()[0]).map(|i| l[[i, i]].ln()).sum::<f64>();
144 -0.5 * (d * (2.0 * PI).ln() + 2.0 * log_det_l + maha)
145}
146
147fn kmeans_pp_init(data: ArrayView2<f64>, k: usize, seed: u64) -> Array2<f64> {
149 let n = data.shape()[0];
150 let d = data.shape()[1];
151
152 let mut rng_state = seed;
153 let lcg = |s: u64| {
154 s.wrapping_mul(6364136223846793005)
155 .wrapping_add(1442695040888963407)
156 };
157 let rand_f64 = |s: &mut u64| -> f64 {
158 *s = lcg(*s);
159 (*s >> 11) as f64 / (1u64 << 53) as f64
160 };
161
162 let mut centers = Array2::<f64>::zeros((k, d));
163 rng_state = lcg(rng_state);
165 let first = (rng_state as usize) % n;
166 centers.row_mut(0).assign(&data.row(first));
167
168 for ci in 1..k {
169 let mut dists = Vec::with_capacity(n);
171 let mut sum_d = 0.0;
172 for i in 0..n {
173 let mut min_d2 = f64::INFINITY;
174 for cj in 0..ci {
175 let d2: f64 = data
176 .row(i)
177 .iter()
178 .zip(centers.row(cj).iter())
179 .map(|(&a, &b)| (a - b) * (a - b))
180 .sum();
181 if d2 < min_d2 {
182 min_d2 = d2;
183 }
184 }
185 dists.push(min_d2);
186 sum_d += min_d2;
187 }
188 let mut u = rand_f64(&mut rng_state) * sum_d;
190 let mut chosen = n - 1;
191 for (i, &d_i) in dists.iter().enumerate() {
192 u -= d_i;
193 if u <= 0.0 {
194 chosen = i;
195 break;
196 }
197 }
198 centers.row_mut(ci).assign(&data.row(chosen));
199 }
200 centers
201}
202
/// Parameters of a fitted Gaussian mixture model, as produced by
/// [`GaussianMixtureModel::fit`].
#[derive(Debug, Clone)]
pub struct GmmParams {
    /// Mixing weight of each component (length `k`). `fit` normalises these to sum to 1.
    pub weights: Array1<f64>,
    /// Component means, one row per component (`k x d`).
    pub means: Array2<f64>,
    /// Lower-triangular Cholesky factor `L` of each component covariance (`Σ = L Lᵀ`).
    pub chol_covs: Vec<Array2<f64>>,
    /// Number of EM iterations that were run.
    pub n_iter: usize,
    /// Whether the change in mean log-likelihood fell below `tol` before `max_iter` was hit.
    pub converged: bool,
    /// Mean per-sample log-likelihood tracked by `fit` for its convergence test.
    /// NOTE(review): it is evaluated before the final M-step, so it can lag the returned
    /// parameters slightly.
    pub log_likelihood: f64,
}
227
228impl GmmParams {
229 pub fn n_components(&self) -> usize {
231 self.weights.len()
232 }
233
234 pub fn n_features(&self) -> usize {
236 self.means.shape()[1]
237 }
238
239 pub fn predict_proba(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
243 let n = data.shape()[0];
244 let k = self.n_components();
245 let mut log_resp = Array2::<f64>::zeros((n, k));
246
247 for i in 0..n {
248 for c in 0..k {
249 if self.weights[c] <= 0.0 {
250 log_resp[[i, c]] = f64::NEG_INFINITY;
251 continue;
252 }
253 log_resp[[i, c]] = self.weights[c].ln()
254 + log_mvn(data.row(i), self.means.row(c), &self.chol_covs[c]);
255 }
256 let row: Vec<f64> = (0..k).map(|c| log_resp[[i, c]]).collect();
258 let lse = logsumexp_row(&row);
259 for c in 0..k {
260 log_resp[[i, c]] = (log_resp[[i, c]] - lse).exp();
261 }
262 }
263 Ok(log_resp)
264 }
265
266 pub fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
268 let proba = self.predict_proba(data)?;
269 let n = proba.shape()[0];
270 let k = proba.shape()[1];
271 let mut labels = Array1::<usize>::zeros(n);
272 for i in 0..n {
273 let mut best = 0;
274 let mut best_p = proba[[i, 0]];
275 for c in 1..k {
276 if proba[[i, c]] > best_p {
277 best_p = proba[[i, c]];
278 best = c;
279 }
280 }
281 labels[i] = best;
282 }
283 Ok(labels)
284 }
285
286 pub fn score(&self, data: ArrayView2<f64>) -> Result<f64> {
288 let n = data.shape()[0];
289 let k = self.n_components();
290 let mut total_ll = 0.0;
291 for i in 0..n {
292 let mut log_terms: Vec<f64> = Vec::with_capacity(k);
293 for c in 0..k {
294 if self.weights[c] > 0.0 {
295 log_terms.push(
296 self.weights[c].ln()
297 + log_mvn(data.row(i), self.means.row(c), &self.chol_covs[c]),
298 );
299 }
300 }
301 total_ll += logsumexp_row(&log_terms);
302 }
303 Ok(total_ll / n as f64)
304 }
305
306 fn n_free_params(&self) -> usize {
311 let k = self.n_components();
312 let d = self.n_features();
313 (k - 1) + k * d + k * (d * (d + 1) / 2)
314 }
315
316 pub fn bic(&self, data: ArrayView2<f64>) -> Result<f64> {
320 let n = data.shape()[0] as f64;
321 let ll = self.score(data)? * n;
322 let p = self.n_free_params() as f64;
323 Ok(-2.0 * ll + p * n.ln())
324 }
325
326 pub fn aic(&self, data: ArrayView2<f64>) -> Result<f64> {
330 let n = data.shape()[0] as f64;
331 let ll = self.score(data)? * n;
332 let p = self.n_free_params() as f64;
333 Ok(-2.0 * ll + 2.0 * p)
334 }
335}
336
/// Stateless entry point for fitting Gaussian mixture models with full covariance matrices
/// by expectation-maximisation; see [`GaussianMixtureModel::fit`].
pub struct GaussianMixtureModel;
348
349impl GaussianMixtureModel {
350 pub fn fit(
363 data: ArrayView2<f64>,
364 n_components: usize,
365 max_iter: usize,
366 tol: f64,
367 ) -> Result<GmmParams> {
368 let n = data.shape()[0];
369 let d = data.shape()[1];
370 let k = n_components;
371
372 if k == 0 {
373 return Err(ClusteringError::InvalidInput(
374 "n_components must be >= 1".to_string(),
375 ));
376 }
377 if n < k {
378 return Err(ClusteringError::InvalidInput(
379 "n_samples must be >= n_components".to_string(),
380 ));
381 }
382 if d == 0 {
383 return Err(ClusteringError::InvalidInput(
384 "n_features must be >= 1".to_string(),
385 ));
386 }
387
388 let reg = 1e-6_f64;
389
390 let init_means = kmeans_pp_init(data, k, 42);
392
393 let mut resp = Array2::<f64>::zeros((n, k));
395 for i in 0..n {
396 let mut best_c = 0;
397 let mut best_d = f64::INFINITY;
398 for c in 0..k {
399 let d2: f64 = data
400 .row(i)
401 .iter()
402 .zip(init_means.row(c).iter())
403 .map(|(&a, &b)| (a - b) * (a - b))
404 .sum();
405 if d2 < best_d {
406 best_d = d2;
407 best_c = c;
408 }
409 }
410 resp[[i, best_c]] = 1.0;
411 }
412
413 let (mut weights, mut means, mut chol_covs) = Self::m_step(data, resp.view(), k, d, reg)?;
415
416 let mut prev_ll = f64::NEG_INFINITY;
417 let mut n_iter = 0;
418 let mut converged = false;
419
420 for iter in 0..max_iter {
421 n_iter = iter + 1;
422
423 resp = Self::e_step(data, &weights, &means, &chol_covs, k)?;
425
426 let ll = Self::mean_log_likelihood(data, &weights, &means, &chol_covs, k);
428
429 if (ll - prev_ll).abs() < tol {
430 converged = true;
431 prev_ll = ll;
432 let (w, m, c) = Self::m_step(data, resp.view(), k, d, reg)?;
434 weights = w;
435 means = m;
436 chol_covs = c;
437 break;
438 }
439 prev_ll = ll;
440
441 let (w, m, c) = Self::m_step(data, resp.view(), k, d, reg)?;
443 weights = w;
444 means = m;
445 chol_covs = c;
446 }
447
448 Ok(GmmParams {
449 weights,
450 means,
451 chol_covs,
452 n_iter,
453 converged,
454 log_likelihood: prev_ll,
455 })
456 }
457
458 fn e_step(
461 data: ArrayView2<f64>,
462 weights: &Array1<f64>,
463 means: &Array2<f64>,
464 chol_covs: &[Array2<f64>],
465 k: usize,
466 ) -> Result<Array2<f64>> {
467 let n = data.shape()[0];
468 let mut log_resp = Array2::<f64>::zeros((n, k));
469
470 for i in 0..n {
471 for c in 0..k {
472 if weights[c] <= 0.0 {
473 log_resp[[i, c]] = f64::NEG_INFINITY;
474 continue;
475 }
476 log_resp[[i, c]] =
477 weights[c].ln() + log_mvn(data.row(i), means.row(c), &chol_covs[c]);
478 }
479 let row: Vec<f64> = (0..k).map(|c| log_resp[[i, c]]).collect();
480 let lse = logsumexp_row(&row);
481 for c in 0..k {
482 log_resp[[i, c]] = (log_resp[[i, c]] - lse).exp();
483 }
484 }
485 Ok(log_resp)
486 }
487
488 fn m_step(
489 data: ArrayView2<f64>,
490 resp: ArrayView2<f64>,
491 k: usize,
492 d: usize,
493 reg: f64,
494 ) -> Result<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
495 let n = data.shape()[0];
496
497 let nk: Vec<f64> = (0..k)
499 .map(|c| (0..n).map(|i| resp[[i, c]]).sum::<f64>().max(1e-10))
500 .collect();
501
502 let total_n: f64 = nk.iter().sum();
503 let weights: Array1<f64> = nk.iter().map(|&nkc| nkc / total_n).collect();
504
505 let mut means = Array2::<f64>::zeros((k, d));
507 for c in 0..k {
508 for i in 0..n {
509 for f in 0..d {
510 means[[c, f]] += resp[[i, c]] * data[[i, f]];
511 }
512 }
513 for f in 0..d {
514 means[[c, f]] /= nk[c];
515 }
516 }
517
518 let mut chol_covs = Vec::with_capacity(k);
520 for c in 0..k {
521 let mut cov = Array2::<f64>::zeros((d, d));
522 for i in 0..n {
523 for f1 in 0..d {
524 let diff_f1 = data[[i, f1]] - means[[c, f1]];
525 for f2 in f1..d {
526 let diff_f2 = data[[i, f2]] - means[[c, f2]];
527 let v = resp[[i, c]] * diff_f1 * diff_f2 / nk[c];
528 cov[[f1, f2]] += v;
529 if f2 != f1 {
530 cov[[f2, f1]] += v;
531 }
532 }
533 }
534 }
535 for f in 0..d {
536 cov[[f, f]] += reg;
537 }
538 let l = cholesky(&cov)?;
539 chol_covs.push(l);
540 }
541
542 Ok((weights, means, chol_covs))
543 }
544
545 fn mean_log_likelihood(
546 data: ArrayView2<f64>,
547 weights: &Array1<f64>,
548 means: &Array2<f64>,
549 chol_covs: &[Array2<f64>],
550 k: usize,
551 ) -> f64 {
552 let n = data.shape()[0];
553 let mut total = 0.0;
554 for i in 0..n {
555 let mut log_terms: Vec<f64> = Vec::with_capacity(k);
556 for c in 0..k {
557 if weights[c] > 0.0 {
558 log_terms
559 .push(weights[c].ln() + log_mvn(data.row(i), means.row(c), &chol_covs[c]));
560 }
561 }
562 total += logsumexp_row(&log_terms);
563 }
564 total / n as f64
565 }
566}
567
/// Result of [`DirichletProcessMixtureModel::fit`]: a truncated mixture of diagonal
/// Gaussians with stick-breaking weights.
#[derive(Debug, Clone)]
pub struct DpmmResult {
    /// Expected mixing weight of each of the `truncation` components, computed as
    /// `E[v_k] Π_{j<k} (1 − E[v_j])`. The weights need not sum exactly to 1.
    pub stick_weights: Array1<f64>,
    /// Posterior mean of each component mean, one row per component (`truncation x d`).
    pub means: Array2<f64>,
    /// Whether each component's expected weight exceeded `activity_threshold / truncation`.
    /// Only active components take part in prediction.
    pub active: Vec<bool>,
    /// Value of the variational objective from the last iteration of `fit`.
    /// It is the quantity tested for convergence.
    pub elbo: f64,
    /// Number of variational iterations that were run.
    pub n_iter: usize,
    /// Whether the objective changed by less than `tol` before `max_iter` was reached.
    pub converged: bool,
    /// Diagonal Cholesky factors (per-feature standard deviations) of each component.
    chol_covs: Vec<Array2<f64>>,
    /// Number of `true` entries in `active`.
    n_active: usize,
}
592
593impl DpmmResult {
594 pub fn n_components(&self) -> usize {
596 self.stick_weights.len()
597 }
598
599 pub fn n_active_components(&self) -> usize {
601 self.n_active
602 }
603
604 pub fn predict_proba(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
606 let n = data.shape()[0];
607 let t = self.n_components();
608 let mut log_resp = Array2::<f64>::zeros((n, t));
609 for i in 0..n {
610 for c in 0..t {
611 let w = self.stick_weights[c];
612 if w <= 0.0 || !self.active[c] {
613 log_resp[[i, c]] = f64::NEG_INFINITY;
614 continue;
615 }
616 log_resp[[i, c]] =
617 w.ln() + log_mvn(data.row(i), self.means.row(c), &self.chol_covs[c]);
618 }
619 let row: Vec<f64> = (0..t).map(|c| log_resp[[i, c]]).collect();
620 let lse = logsumexp_row(&row);
621 for c in 0..t {
622 log_resp[[i, c]] = (log_resp[[i, c]] - lse).exp();
623 }
624 }
625 Ok(log_resp)
626 }
627
628 pub fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
630 let proba = self.predict_proba(data)?;
631 let n = proba.shape()[0];
632 let t = proba.shape()[1];
633 let mut labels = Array1::<usize>::zeros(n);
634 for i in 0..n {
635 let mut best = 0;
636 let mut best_p = proba[[i, 0]];
637 for c in 1..t {
638 if proba[[i, c]] > best_p {
639 best_p = proba[[i, c]];
640 best = c;
641 }
642 }
643 labels[i] = best;
644 }
645 Ok(labels)
646 }
647}
648
/// Configuration for a truncated stick-breaking Dirichlet-process mixture of diagonal
/// Gaussians; see [`DirichletProcessMixtureModel::fit`].
pub struct DirichletProcessMixtureModel {
    /// Concentration of the `Beta(1, alpha)` stick prior.
    /// Larger values spread the expected weight over more sticks.
    pub alpha: f64,
    /// Truncation level: the maximum number of mixture components considered.
    pub truncation: usize,
    /// Maximum number of variational iterations (`new` sets 200).
    pub max_iter: usize,
    /// Convergence tolerance on the absolute change of the objective between iterations
    /// (`new` sets `1e-4`).
    pub tol: f64,
    /// A component is active when its expected weight exceeds
    /// `activity_threshold / truncation` (`new` sets `1e-2`).
    pub activity_threshold: f64,
}
672
673impl DirichletProcessMixtureModel {
674 pub fn new(alpha: f64, truncation: usize) -> Self {
676 Self {
677 alpha,
678 truncation,
679 max_iter: 200,
680 tol: 1e-4,
681 activity_threshold: 1e-2,
682 }
683 }
684
685 pub fn fit(&self, data: ArrayView2<f64>) -> Result<DpmmResult> {
687 let n = data.shape()[0];
688 let d = data.shape()[1];
689 let t = self.truncation;
690
691 if n == 0 || d == 0 {
692 return Err(ClusteringError::InvalidInput(
693 "Data must be non-empty".to_string(),
694 ));
695 }
696 if t < 1 {
697 return Err(ClusteringError::InvalidInput(
698 "truncation must be >= 1".to_string(),
699 ));
700 }
701
702 let reg = 1e-6_f64;
703 let alpha = self.alpha;
704
705 let k_init = t.min(n);
707 let init_means = kmeans_pp_init(data, k_init, 7);
708
709 let mut phi = Array2::<f64>::zeros((n, t));
711 for i in 0..n {
712 let mut best_c = 0;
713 let mut best_d = f64::INFINITY;
714 for c in 0..k_init {
715 let d2: f64 = data
716 .row(i)
717 .iter()
718 .zip(init_means.row(c).iter())
719 .map(|(&a, &b)| (a - b) * (a - b))
720 .sum();
721 if d2 < best_d {
722 best_d = d2;
723 best_c = c;
724 }
725 }
726 phi[[i, best_c]] = 1.0;
727 }
728
729 let mut a_gamma = Array1::<f64>::from_elem(t, 1.0);
732 let mut b_gamma = Array1::<f64>::from_elem(t, alpha);
733
734 let mut m = Array2::<f64>::zeros((t, d)); let mut beta_k = Array1::<f64>::from_elem(t, 1.0); let mut nu_k = Array1::<f64>::from_elem(t, d as f64 + 1.0); let mut w_k = Array2::<f64>::from_elem((t, d), 1.0); for c in 0..k_init {
742 for f in 0..d {
743 m[[c, f]] = init_means[[c, f]];
744 }
745 }
746
747 let mut prev_elbo = f64::NEG_INFINITY;
748 let mut n_iter = 0;
749 let mut converged = false;
750
751 for iter in 0..self.max_iter {
752 n_iter = iter + 1;
753
754 let mut e_log_pi = Array1::<f64>::zeros(t);
757 let mut cumsum_b = 0.0;
758 for k in 0..t {
759 let e_log_v_k = digamma(a_gamma[k]) - digamma(a_gamma[k] + b_gamma[k]);
760 let e_log_1mv_k = digamma(b_gamma[k]) - digamma(a_gamma[k] + b_gamma[k]);
761 e_log_pi[k] = e_log_v_k + cumsum_b;
762 cumsum_b += e_log_1mv_k;
763 }
764
765 let e_log_lam: Vec<f64> = (0..t)
768 .map(|k| {
769 (0..d)
770 .map(|f| {
771 let dof_f = (nu_k[k] + 1.0 - f as f64) / 2.0;
772 digamma(dof_f.max(0.5)) + (2.0 * w_k[[k, f]]).ln()
773 })
774 .sum::<f64>()
775 })
776 .collect();
777
778 for i in 0..n {
779 let mut log_rho = Vec::with_capacity(t);
780 for k in 0..t {
781 let trace_term: f64 = (0..d)
783 .map(|f| {
784 nu_k[k] * w_k[[k, f]] * (data[[i, f]] - m[[k, f]]).powi(2)
785 + 1.0 / beta_k[k]
786 })
787 .sum();
788 log_rho.push(
789 e_log_pi[k] + 0.5 * e_log_lam[k]
790 - 0.5 * d as f64 * (2.0 * PI).ln()
791 - 0.5 * trace_term,
792 );
793 }
794 let lse = logsumexp_row(&log_rho);
795 for k in 0..t {
796 phi[[i, k]] = (log_rho[k] - lse).exp();
797 }
798 }
799
800 let nk: Vec<f64> = (0..t)
802 .map(|k| (0..n).map(|i| phi[[i, k]]).sum::<f64>().max(1e-10))
803 .collect();
804
805 for k in 0..t {
806 let sum_after: f64 = nk[(k + 1)..].iter().sum();
807 a_gamma[k] = 1.0 + nk[k];
808 b_gamma[k] = alpha + sum_after;
809 }
810
811 for k in 0..t {
813 let beta_0 = 1.0;
814 let nu_0 = d as f64 + 1.0;
815
816 beta_k[k] = beta_0 + nk[k];
818 let mut x_bar = vec![0.0_f64; d];
819 for i in 0..n {
820 for f in 0..d {
821 x_bar[f] += phi[[i, k]] * data[[i, f]];
822 }
823 }
824 for f in 0..d {
825 x_bar[f] /= nk[k];
826 m[[k, f]] = (beta_0 * 0.0 + nk[k] * x_bar[f]) / beta_k[k];
827 }
828
829 nu_k[k] = nu_0 + nk[k];
831
832 for f in 0..d {
834 let mut scatter = 0.0;
835 for i in 0..n {
836 scatter += phi[[i, k]] * (data[[i, f]] - x_bar[f]).powi(2);
837 }
838 let bc_correction = beta_0 * nk[k] / beta_k[k] * x_bar[f].powi(2);
839 w_k[[k, f]] = 1.0 / (1.0 / (1.0 + reg) + scatter + bc_correction);
840 }
841 }
842
843 let elbo = Self::compute_elbo(
845 data, &phi, &a_gamma, &b_gamma, &m, &beta_k, &nu_k, &w_k, alpha, n, d, t,
846 );
847
848 if (elbo - prev_elbo).abs() < self.tol {
849 converged = true;
850 prev_elbo = elbo;
851 break;
852 }
853 prev_elbo = elbo;
854 }
855
856 let mut expected_weights = Array1::<f64>::zeros(t);
859 let mut log_remaining: f64 = 0.0;
860 for k in 0..t {
861 let e_v_k = a_gamma[k] / (a_gamma[k] + b_gamma[k]);
862 expected_weights[k] = e_v_k * log_remaining.exp();
863 log_remaining += (1.0 - e_v_k).ln();
864 }
865
866 let active: Vec<bool> = (0..t)
867 .map(|k| expected_weights[k] > self.activity_threshold / t as f64)
868 .collect();
869 let n_active = active.iter().filter(|&&a| a).count();
870
871 let mut chol_covs = Vec::with_capacity(t);
873 for k in 0..t {
874 let mut cov = Array2::<f64>::zeros((d, d));
876 for f in 0..d {
877 let var = (1.0 / (nu_k[k] * w_k[[k, f]])).max(reg);
878 cov[[f, f]] = var.sqrt(); }
880 chol_covs.push(cov);
881 }
882
883 let final_means = m.clone();
884
885 Ok(DpmmResult {
886 stick_weights: expected_weights,
887 means: final_means,
888 active,
889 elbo: prev_elbo,
890 n_iter,
891 converged,
892 chol_covs,
893 n_active,
894 })
895 }
896
897 #[allow(clippy::too_many_arguments)]
899 fn compute_elbo(
900 data: ArrayView2<f64>,
901 phi: &Array2<f64>,
902 a_gamma: &Array1<f64>,
903 b_gamma: &Array1<f64>,
904 m: &Array2<f64>,
905 beta_k: &Array1<f64>,
906 nu_k: &Array1<f64>,
907 w_k: &Array2<f64>,
908 alpha: f64,
909 n: usize,
910 d: usize,
911 t: usize,
912 ) -> f64 {
913 let mut ll = 0.0;
915 for i in 0..n {
916 for k in 0..t {
917 if phi[[i, k]] < 1e-15 {
918 continue;
919 }
920 let log_norm = -(d as f64) / 2.0 * (2.0 * PI).ln();
921 let neg_quad: f64 = -(0..d)
922 .map(|f| nu_k[k] * w_k[[k, f]] * (data[[i, f]] - m[[k, f]]).powi(2))
923 .sum::<f64>()
924 / 2.0;
925 let e_log_lam: f64 = (0..d)
926 .map(|f| {
927 let dof_f = (nu_k[k] + 1.0 - f as f64) / 2.0;
928 digamma(dof_f.max(0.5)) + (2.0 * w_k[[k, f]]).ln()
929 })
930 .sum::<f64>()
931 / 2.0;
932 ll += phi[[i, k]] * (log_norm + e_log_lam + neg_quad);
933 }
934 }
935
936 let mut z_term = 0.0;
938 for i in 0..n {
939 for k in 0..t {
940 let phi_ik = phi[[i, k]];
941 if phi_ik > 1e-15 {
942 z_term -= phi_ik * phi_ik.ln(); }
944 }
945 }
946
947 let dp_term: f64 = (0..t)
949 .map(|k| (alpha - 1.0) * (digamma(b_gamma[k]) - digamma(a_gamma[k] + b_gamma[k])))
950 .sum();
951
952 let beta_entropy: f64 = (0..t)
954 .map(|k| {
955 let ab = a_gamma[k] + b_gamma[k];
956 let ent = (beta_k[k]).ln() - (a_gamma[k] - 1.0) * digamma(a_gamma[k]) + (ab).ln()
957 - (b_gamma[k] - 1.0) * digamma(b_gamma[k])
958 + digamma(ab);
959 ent
960 })
961 .sum();
962
963 ll + z_term + dp_term + beta_entropy
964 }
965}
966
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Two tight groups of six points, around (1, 1) and (5, 5).
    fn two_cluster_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (12, 2),
            vec![
                1.0, 1.0, 1.1, 0.9, 0.9, 1.1, 1.0, 1.0, 0.8, 1.2, 1.2, 0.8, 5.0, 5.0, 5.1, 4.9,
                4.9, 5.1, 5.0, 5.0, 4.8, 5.2, 5.2, 4.8,
            ],
        )
        .expect("data")
    }

    #[test]
    fn test_gmm_fit_basic() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        assert_eq!(params.n_components(), 2);
        assert_eq!(params.n_features(), 2);
        assert!(params.converged || params.n_iter > 0);
    }

    #[test]
    fn test_gmm_predict_proba() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let proba = params.predict_proba(data.view()).expect("predict_proba");
        assert_eq!(proba.shape(), [12, 2]);
        for i in 0..12 {
            // Every row of responsibilities is a probability distribution.
            let row_sum: f64 = (0..2).map(|c| proba[[i, c]]).sum();
            assert!((row_sum - 1.0).abs() < 1e-6, "row {i} sums to {row_sum}");
        }
    }

    #[test]
    fn test_gmm_predict_hard() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let labels = params.predict(data.view()).expect("predict");
        assert_eq!(labels.len(), 12);
        let unique: std::collections::HashSet<_> = labels.iter().copied().collect();
        assert!(unique.len() <= 2);
    }

    #[test]
    fn test_gmm_separates_clusters() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let labels = params.predict(data.view()).expect("predict");
        assert!(labels.iter().take(6).all(|&l| l == labels[0]));
        assert!(labels.iter().skip(6).all(|&l| l == labels[6]));
        assert_ne!(labels[0], labels[6]);
    }

    #[test]
    fn test_gmm_score_finite() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let score = params.score(data.view()).expect("score");
        assert!(score.is_finite(), "score must be finite, got {score}");
    }

    #[test]
    fn test_gmm_bic_aic() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 2, 100, 1e-4).expect("gmm fit");
        let bic = params.bic(data.view()).expect("bic");
        let aic = params.aic(data.view()).expect("aic");
        assert!(bic.is_finite());
        assert!(aic.is_finite());
    }

    #[test]
    fn test_gmm_k1_trivial() {
        let data = two_cluster_data();
        let params = GaussianMixtureModel::fit(data.view(), 1, 50, 1e-4).expect("gmm k=1");
        let labels = params.predict(data.view()).expect("predict k=1");
        assert!(labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_gmm_invalid_k() {
        let data = two_cluster_data();
        let result = GaussianMixtureModel::fit(data.view(), 0, 50, 1e-4);
        assert!(result.is_err());
    }

    #[test]
    fn test_gmm_too_few_samples() {
        let data = two_cluster_data();
        assert!(GaussianMixtureModel::fit(data.view(), 13, 50, 1e-4).is_err());
    }

    #[test]
    fn test_gmm_zero_features() {
        let data = Array2::<f64>::zeros((4, 0));
        assert!(GaussianMixtureModel::fit(data.view(), 1, 50, 1e-4).is_err());
    }

    #[test]
    fn test_dpmm_fit_basic() {
        let data = two_cluster_data();
        let model = DirichletProcessMixtureModel::new(1.0, 6);
        let result = model.fit(data.view()).expect("dpmm fit");
        assert_eq!(result.n_components(), 6);
        assert!(result.n_iter > 0);
        assert!(result.n_active_components() >= 1);
    }

    #[test]
    fn test_dpmm_predict_proba() {
        let data = two_cluster_data();
        let model = DirichletProcessMixtureModel::new(1.0, 4);
        let result = model.fit(data.view()).expect("dpmm fit");
        let proba = result.predict_proba(data.view()).expect("proba");
        assert_eq!(proba.shape()[0], 12);
        assert_eq!(proba.shape()[1], 4);
        for i in 0..12 {
            let row_sum: f64 = (0..4).map(|c| proba[[i, c]]).sum();
            assert!((row_sum - 1.0).abs() < 1e-5, "row {i} sum {row_sum}");
        }
    }

    #[test]
    fn test_dpmm_predict_hard() {
        let data = two_cluster_data();
        let model = DirichletProcessMixtureModel::new(1.0, 4);
        let result = model.fit(data.view()).expect("dpmm fit");
        let labels = result.predict(data.view()).expect("predict");
        assert_eq!(labels.len(), 12);
    }

    #[test]
    fn test_dpmm_alpha_concentration() {
        let data = two_cluster_data();
        // A larger concentration should keep at least as many components active.
        let model_low = DirichletProcessMixtureModel::new(0.01, 8);
        let model_high = DirichletProcessMixtureModel::new(10.0, 8);
        let r_low = model_low.fit(data.view()).expect("low alpha");
        let r_high = model_high.fit(data.view()).expect("high alpha");
        assert!(r_high.n_active_components() >= r_low.n_active_components());
    }

    #[test]
    fn test_dpmm_rejects_empty_data() {
        let data = Array2::<f64>::zeros((0, 2));
        let model = DirichletProcessMixtureModel::new(1.0, 4);
        assert!(model.fit(data.view()).is_err());
    }

    #[test]
    fn test_dpmm_rejects_zero_truncation() {
        let data = two_cluster_data();
        let model = DirichletProcessMixtureModel::new(1.0, 0);
        assert!(model.fit(data.view()).is_err());
    }

    #[test]
    fn test_logsumexp_row_basic() {
        assert_eq!(logsumexp_row(&[]), f64::NEG_INFINITY);
        assert!((logsumexp_row(&[0.0, 0.0]) - 2f64.ln()).abs() < 1e-12);
        assert!((logsumexp_row(&[1000.0, 1000.0]) - (1000.0 + 2f64.ln())).abs() < 1e-9);
    }

    #[test]
    fn test_digamma_known_values() {
        // ψ(1) = −γ (Euler–Mascheroni) and ψ(1/2) = −γ − 2 ln 2.
        assert!((digamma(1.0) + 0.577_215_664_901_532_9).abs() < 1e-7);
        assert!((digamma(0.5) + 1.963_510_026_021_423_5).abs() < 1e-7);
    }
}