use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::{Rng, SeedableRng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::f64::consts::PI;

use crate::common::CovarianceType;
/// Numerically stable evaluation of `ln(exp(a) + exp(b))`.
fn log_sum_exp(a: f64, b: f64) -> f64 {
    let max_val = a.max(b);
    if max_val.is_finite() {
        max_val + ((a - max_val).exp() + (b - max_val).exp()).ln()
    } else {
        max_val
    }
}

/// Chinese Restaurant Process mixture model.
///
/// Samples ("customers") are seated at components ("tables") by a collapsed
/// Gibbs-style sampler: an occupied table is chosen with probability
/// proportional to its occupancy, and a new table with probability
/// proportional to the concentration parameter `alpha`, up to
/// `max_components` tables.
#[derive(Debug, Clone)]
pub struct ChineseRestaurantProcess<S = Untrained> {
    state: S,
    /// Concentration parameter of the process.
    alpha: f64,
    /// Upper bound on the number of components (tables).
    max_components: usize,
    /// Covariance parameterization of the component Gaussians.
    covariance_type: CovarianceType,
    /// Convergence tolerance on the log-likelihood.
    tol: f64,
    /// Maximum number of Gibbs sweeps.
    max_iter: usize,
    /// Optional RNG seed for reproducible fits.
    random_state: Option<u64>,
    /// Ridge added to covariance diagonals for numerical stability.
    reg_covar: f64,
}
impl ChineseRestaurantProcess<Untrained> {
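    /// Creates an untrained model with the default configuration
    /// (`alpha = 1.0`, `max_components = 20`, full covariances,
    /// `tol = 1e-3`, `max_iter = 100`, `reg_covar = 1e-6`).
    ///
    /// A minimal usage sketch; the crate path `sklears_mixture` and the
    /// `f64` feature type are illustrative assumptions:
    ///
    /// ```no_run
    /// use scirs2_core::ndarray::Array2;
    /// use sklears_core::traits::{Fit, Predict};
    /// use sklears_mixture::ChineseRestaurantProcess;
    ///
    /// let x = Array2::from_shape_vec(
    ///     (4, 2),
    ///     vec![0.0, 0.1, 0.2, -0.1, 5.0, 5.1, 5.2, 4.9],
    /// )
    /// .unwrap();
    /// let model = ChineseRestaurantProcess::new()
    ///     .alpha(0.5)
    ///     .random_state(42)
    ///     .fit(&x.view(), &())
    ///     .unwrap();
    /// let labels = model.predict(&x.view()).unwrap();
    /// assert_eq!(labels.len(), 4);
    /// ```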
    pub fn new() -> Self {
        Self {
            state: Untrained,
            alpha: 1.0,
            max_components: 20,
            covariance_type: CovarianceType::Full,
            tol: 1e-3,
            max_iter: 100,
            random_state: None,
            reg_covar: 1e-6,
        }
    }

    /// Sets the concentration parameter; larger values encourage more tables.
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the upper bound on the number of components.
    pub fn max_components(mut self, max_components: usize) -> Self {
        self.max_components = max_components;
        self
    }

    /// Sets the covariance parameterization.
    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Sets the convergence tolerance on the log-likelihood.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Sets the maximum number of Gibbs sweeps.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the RNG seed for reproducible fits.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Sets the ridge added to covariance diagonals.
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Sequentially seats each sample according to the CRP prior: existing
    /// table `k` is chosen with probability `n_k / (i + alpha)` and a new
    /// table with probability `alpha / (i + alpha)`.
    fn initialize_tables(
        &self,
        X: &Array2<f64>,
        rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> SklResult<(Array1<usize>, Array1<usize>, usize)> {
        let n_samples = X.nrows();
        let mut table_assignments = Array1::zeros(n_samples);
        let mut table_counts: Array1<usize> = Array1::zeros(self.max_components);
        let mut n_tables = 1;

        // The first customer always opens the first table.
        table_assignments[0] = 0;
        table_counts[0] = 1;

        for i in 1..n_samples {
            let total_customers = i;
            let mut probabilities = Array1::zeros(n_tables + 1);

            // Existing tables: probability proportional to occupancy.
            for k in 0..n_tables {
                probabilities[k] = table_counts[k] as f64 / (total_customers as f64 + self.alpha);
            }

            // New table: probability proportional to alpha, capped at max_components.
            if n_tables < self.max_components {
                probabilities[n_tables] = self.alpha / (total_customers as f64 + self.alpha);
            }

            // Sample a table by inverse-CDF over the unnormalized masses.
            let total_mass: f64 = probabilities.iter().sum();
            let mut cumulative = 0.0;
            let target = rng.gen::<f64>() * total_mass;

            let mut chosen_table = 0;
            for k in 0..=n_tables {
                cumulative += probabilities[k];
                if target <= cumulative {
                    chosen_table = k;
                    break;
                }
            }

            table_assignments[i] = chosen_table;
            table_counts[chosen_table] += 1;

            if chosen_table == n_tables && n_tables < self.max_components {
                n_tables += 1;
            }
        }

        Ok((table_assignments, table_counts, n_tables))
    }

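    /// Estimates mixture parameters from the current table assignments:
    /// weights are occupancy fractions, means are per-table sample means,
    /// and covariances are per-table sample covariances (identity for
    /// singleton tables), regularized according to `covariance_type`.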
    fn compute_parameters(
        &self,
        X: &Array2<f64>,
        table_assignments: &Array1<usize>,
        n_tables: usize,
    ) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
        let (n_samples, n_features) = X.dim();

        // Mixture weights: fraction of samples seated at each table.
        let mut weights = Array1::zeros(n_tables);
        for &assignment in table_assignments.iter() {
            if assignment < n_tables {
                weights[assignment] += 1.0;
            }
        }
        weights /= n_samples as f64;

        // Component means: per-table sample means.
        let mut means = Array2::zeros((n_tables, n_features));
        let mut counts: Array1<f64> = Array1::zeros(n_tables);

        for (i, &assignment) in table_assignments.iter().enumerate() {
            if assignment < n_tables {
                for j in 0..n_features {
                    means[[assignment, j]] += X[[i, j]];
                }
                counts[assignment] += 1.0;
            }
        }

        for k in 0..n_tables {
            if counts[k] > 0.0 {
                for j in 0..n_features {
                    means[[k, j]] /= counts[k];
                }
            }
        }

        // Component covariances: per-table sample covariances, falling back
        // to the identity for tables with at most one sample.
        let mut covariances = Vec::new();
        for k in 0..n_tables {
            let mut cov = Array2::zeros((n_features, n_features));
            let mut count = 0.0;

            for (i, &assignment) in table_assignments.iter().enumerate() {
                if assignment == k {
                    let diff = &X.row(i) - &means.row(k);
                    for j in 0..n_features {
                        for l in 0..n_features {
                            cov[[j, l]] += diff[j] * diff[l];
                        }
                    }
                    count += 1.0;
                }
            }

            if count > 1.0 {
                cov /= count - 1.0;
            } else {
                for j in 0..n_features {
                    cov[[j, j]] = 1.0;
                }
            }

            cov = self.regularize_covariance(cov)?;
            covariances.push(cov);
        }

        Ok((weights, means, covariances))
    }

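    /// Applies the configured covariance structure: `Full` and `Tied` add
    /// `reg_covar` to the diagonal, `Diagonal` additionally zeroes the
    /// off-diagonal entries, and `Spherical` replaces the matrix with a
    /// scaled identity whose scale is the mean of the diagonal.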
    fn regularize_covariance(&self, mut cov: Array2<f64>) -> SklResult<Array2<f64>> {
        let n_features = cov.dim().0;

        match self.covariance_type {
            CovarianceType::Full => {
                for i in 0..n_features {
                    cov[[i, i]] += self.reg_covar;
                }
            }
            CovarianceType::Diagonal => {
                for i in 0..n_features {
                    for j in 0..n_features {
                        if i != j {
                            cov[[i, j]] = 0.0;
                        } else {
                            cov[[i, i]] += self.reg_covar;
                        }
                    }
                }
            }
            // Tied covariances are regularized per component here; sharing a
            // single matrix across components is not enforced at this stage.
            CovarianceType::Tied => {
                for i in 0..n_features {
                    cov[[i, i]] += self.reg_covar;
                }
            }
            CovarianceType::Spherical => {
                let avg_variance =
                    (0..n_features).map(|i| cov[[i, i]]).sum::<f64>() / n_features as f64;
                cov.fill(0.0);
                for i in 0..n_features {
                    cov[[i, i]] = avg_variance + self.reg_covar;
                }
            }
        }

        Ok(cov)
    }

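    /// Total log-likelihood of `X` under the mixture, summing each sample's
    /// density over components in probability space; samples whose density
    /// underflows below `1e-300` are skipped to avoid `ln(0)`.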
    fn compute_log_likelihood(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
    ) -> SklResult<f64> {
        let n_samples = X.nrows();
        let n_components = weights.len();
        let mut log_likelihood = 0.0;

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut sample_likelihood = 0.0;

            for k in 0..n_components {
                if weights[k] > 1e-10 {
                    let log_pdf =
                        self.multivariate_normal_log_pdf(&sample, &means.row(k), &covariances[k])?;
                    sample_likelihood += weights[k] * log_pdf.exp();
                }
            }

            // Skip samples whose density underflows, to avoid ln(0).
            if sample_likelihood > 1e-300 {
                log_likelihood += sample_likelihood.ln();
            }
        }

        Ok(log_likelihood)
    }

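    /// Log-density of a multivariate Gaussian, evaluated directly as
    /// `-0.5 * (d * ln(2*pi) + ln|cov| + (x - mean)^T cov^{-1} (x - mean))`;
    /// returns negative infinity when the determinant is non-positive.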
    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff = x - mean;

        let det = self.matrix_determinant(cov)?;
        if det <= 0.0 {
            return Ok(f64::NEG_INFINITY);
        }

        let inv_cov = self.matrix_inverse(cov)?;
        let mut quad_form = 0.0;
        for i in 0..diff.len() {
            for j in 0..diff.len() {
                quad_form += diff[i] * inv_cov[[i, j]] * diff[j];
            }
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + det.ln() + quad_form))
    }

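    /// Determinant via Gaussian elimination with partial pivoting; a pivot
    /// whose magnitude falls below `1e-12` is treated as singular and yields
    /// a determinant of zero.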
    fn matrix_determinant(&self, A: &Array2<f64>) -> SklResult<f64> {
        let n = A.dim().0;
        if n == 1 {
            return Ok(A[[0, 0]]);
        }
        if n == 2 {
            return Ok(A[[0, 0]] * A[[1, 1]] - A[[0, 1]] * A[[1, 0]]);
        }

        let mut det = 1.0;
        let mut A_copy = A.clone();

        for i in 0..n {
            // Partial pivoting: bring the largest remaining entry in this
            // column onto the diagonal.
            let mut max_row = i;
            for k in i + 1..n {
                if A_copy[[k, i]].abs() > A_copy[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..n {
                    let temp = A_copy[[i, j]];
                    A_copy[[i, j]] = A_copy[[max_row, j]];
                    A_copy[[max_row, j]] = temp;
                }
                // Each row swap flips the determinant's sign.
                det *= -1.0;
            }

            if A_copy[[i, i]].abs() < 1e-12 {
                return Ok(0.0);
            }

            det *= A_copy[[i, i]];

            for k in i + 1..n {
                let factor = A_copy[[k, i]] / A_copy[[i, i]];
                for j in i..n {
                    A_copy[[k, j]] -= factor * A_copy[[i, j]];
                }
            }
        }

        Ok(det)
    }

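    /// Inverse via Gauss-Jordan elimination on the augmented matrix `[A | I]`
    /// with partial pivoting; returns a numerical error when a pivot falls
    /// below `1e-12`.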
    fn matrix_inverse(&self, A: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n = A.dim().0;
        let mut aug = Array2::zeros((n, 2 * n));

        // Build the augmented matrix [A | I].
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = A[[i, j]];
                aug[[i, j + n]] = if i == j { 1.0 } else { 0.0 };
            }
        }

        for i in 0..n {
            // Partial pivoting.
            let mut max_row = i;
            for k in i + 1..n {
                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..2 * n {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[i, i]].abs() < 1e-12 {
                return Err(SklearsError::NumericalError(
                    "Matrix is singular".to_string(),
                ));
            }

            // Normalize the pivot row, then eliminate the column elsewhere.
            let pivot = aug[[i, i]];
            for j in 0..2 * n {
                aug[[i, j]] /= pivot;
            }

            for k in 0..n {
                if k != i {
                    let factor = aug[[k, i]];
                    for j in 0..2 * n {
                        aug[[k, j]] -= factor * aug[[i, j]];
                    }
                }
            }
        }

        // The right half now holds the inverse.
        let mut inv = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                inv[[i, j]] = aug[[i, j + n]];
            }
        }

        Ok(inv)
    }
}

impl Default for ChineseRestaurantProcess<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for ChineseRestaurantProcess<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for ChineseRestaurantProcess<Untrained> {
    type Fitted = ChineseRestaurantProcess<ChineseRestaurantProcessTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "No samples provided".to_string(),
            ));
        }

        // Fall back to a fixed seed when none is given, for reproducibility.
        let mut rng = match self.random_state {
            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
            None => scirs2_core::random::rngs::StdRng::seed_from_u64(42),
        };

        let (mut table_assignments, mut table_counts, mut n_tables) =
            self.initialize_tables(&X, &mut rng)?;

        let mut prev_log_likelihood = f64::NEG_INFINITY;
        let mut converged = false;
        let mut n_iter = 0;

        for iteration in 0..self.max_iter {
            n_iter = iteration + 1;

            // One Gibbs sweep: reseat every customer in turn.
            for i in 0..n_samples {
                // Remove customer i from its current table.
                let current_table = table_assignments[i];
                table_counts[current_table] -= 1;

                // Only a trailing empty table shrinks the table count; an
                // empty table in the middle simply receives zero probability
                // below.
                if table_counts[current_table] == 0 && current_table == n_tables - 1 {
                    n_tables -= 1;
                }

                let mut probabilities = Array1::zeros(n_tables + 1);
                let remaining_customers = n_samples - 1;

                for k in 0..n_tables {
                    if table_counts[k] > 0 {
                        probabilities[k] =
                            table_counts[k] as f64 / (remaining_customers as f64 + self.alpha);
                    }
                }

                if n_tables < self.max_components {
                    probabilities[n_tables] =
                        self.alpha / (remaining_customers as f64 + self.alpha);
                }

                // Sample the new table by inverse-CDF over the unnormalized
                // masses.
                let total_mass: f64 = probabilities.iter().sum();
                let mut cumulative = 0.0;
                let target = rng.gen::<f64>() * total_mass;

                let mut new_table = 0;
                for k in 0..=n_tables {
                    cumulative += probabilities[k];
                    if target <= cumulative {
                        new_table = k;
                        break;
                    }
                }

                table_assignments[i] = new_table;
                table_counts[new_table] += 1;

                if new_table == n_tables && n_tables < self.max_components {
                    n_tables += 1;
                }
            }

            // Check convergence every fifth sweep to amortize the cost of
            // recomputing the parameters and the log-likelihood.
            if iteration % 5 == 0 {
                let (weights, means, covariances) =
                    self.compute_parameters(&X, &table_assignments, n_tables)?;
                let log_likelihood =
                    self.compute_log_likelihood(&X, &weights, &means, &covariances)?;

                if (log_likelihood - prev_log_likelihood).abs() < self.tol {
                    converged = true;
                    break;
                }
                prev_log_likelihood = log_likelihood;
            }
        }

        let (weights, means, covariances) =
            self.compute_parameters(&X, &table_assignments, n_tables)?;
        let log_likelihood = self.compute_log_likelihood(&X, &weights, &means, &covariances)?;

        Ok(ChineseRestaurantProcess {
            state: ChineseRestaurantProcessTrained {
                n_components: n_tables,
                weights,
                means,
                covariances,
                covariance_type: self.covariance_type.clone(),
                n_features,
                alpha: self.alpha,
                table_assignments,
                table_counts: table_counts.slice(s![..n_tables]).to_owned(),
                log_likelihood,
                n_iter,
                converged,
                reg_covar: self.reg_covar,
            },
            alpha: self.alpha,
            max_components: self.max_components,
            covariance_type: self.covariance_type,
            tol: self.tol,
            max_iter: self.max_iter,
            random_state: self.random_state,
            reg_covar: self.reg_covar,
        })
    }
}

/// Fitted state of a [`ChineseRestaurantProcess`].
#[derive(Debug, Clone)]
pub struct ChineseRestaurantProcessTrained {
    /// Number of occupied tables (components) after fitting.
    pub n_components: usize,
    /// Mixture weights.
    pub weights: Array1<f64>,
    /// Component means, one row per component.
    pub means: Array2<f64>,
    /// Component covariance matrices.
    pub covariances: Vec<Array2<f64>>,
    /// Covariance parameterization used during fitting.
    pub covariance_type: CovarianceType,
    /// Number of features seen during fitting.
    pub n_features: usize,
    /// Concentration parameter used during fitting.
    pub alpha: f64,
    /// Final table assignment of each training sample.
    pub table_assignments: Array1<usize>,
    /// Number of samples seated at each table.
    pub table_counts: Array1<usize>,
    /// Log-likelihood of the training data under the fitted mixture.
    pub log_likelihood: f64,
    /// Number of Gibbs sweeps performed.
    pub n_iter: usize,
    /// Whether the log-likelihood converged within `tol`.
    pub converged: bool,
    /// Ridge used for covariance regularization.
    pub reg_covar: f64,
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for ChineseRestaurantProcess<ChineseRestaurantProcessTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut max_log_prob = f64::NEG_INFINITY;
            let mut best_component = 0;

            // Hard assignment: argmax over ln(w_k) + ln N(x_i; mu_k, cov_k).
            for k in 0..self.state.n_components {
                let log_weight = self.state.weights[k].ln();
                let log_pdf = self.multivariate_normal_log_pdf(
                    &sample,
                    &self.state.means.row(k),
                    &self.state.covariances[k],
                )?;
                let log_prob = log_weight + log_pdf;

                if log_prob > max_log_prob {
                    max_log_prob = log_prob;
                    best_component = k;
                }
            }

            predictions[i] = best_component as i32;
        }

        Ok(predictions)
    }
}

impl ChineseRestaurantProcess<ChineseRestaurantProcessTrained> {
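    /// Posterior component probabilities for each sample, normalized in
    /// log space with the log-sum-exp trick.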
    #[allow(non_snake_case)]
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut probabilities = Array2::zeros((n_samples, self.state.n_components));

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_probs = Array1::zeros(self.state.n_components);

            for k in 0..self.state.n_components {
                let log_weight = self.state.weights[k].ln();
                let log_pdf = self.multivariate_normal_log_pdf(
                    &sample,
                    &self.state.means.row(k),
                    &self.state.covariances[k],
                )?;
                log_probs[k] = log_weight + log_pdf;
            }

            // Normalize in log space (log-sum-exp) before exponentiating.
            let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let log_sum_exp = max_log_prob
                + log_probs
                    .iter()
                    .map(|&lp| (lp - max_log_prob).exp())
                    .sum::<f64>()
                    .ln();

            for k in 0..self.state.n_components {
                probabilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
            }
        }

        Ok(probabilities)
    }

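    /// Total log-likelihood of `X` under the fitted mixture.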
    #[allow(non_snake_case)]
    pub fn score(&self, X: &ArrayView2<'_, Float>) -> SklResult<f64> {
        let X = X.to_owned();
        self.compute_log_likelihood(
            &X,
            &self.state.weights,
            &self.state.means,
            &self.state.covariances,
        )
    }

    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff = x - mean;

        let det = self.matrix_determinant(cov)?;
        if det <= 0.0 {
            return Ok(f64::NEG_INFINITY);
        }

        let inv_cov = self.matrix_inverse(cov)?;
        let mut quad_form = 0.0;
        for i in 0..diff.len() {
            for j in 0..diff.len() {
                quad_form += diff[i] * inv_cov[[i, j]] * diff[j];
            }
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + det.ln() + quad_form))
    }

    fn matrix_determinant(&self, A: &Array2<f64>) -> SklResult<f64> {
        let n = A.dim().0;
        if n == 1 {
            return Ok(A[[0, 0]]);
        }
        if n == 2 {
            return Ok(A[[0, 0]] * A[[1, 1]] - A[[0, 1]] * A[[1, 0]]);
        }

        let mut det = 1.0;
        let mut A_copy = A.clone();

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if A_copy[[k, i]].abs() > A_copy[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..n {
                    let temp = A_copy[[i, j]];
                    A_copy[[i, j]] = A_copy[[max_row, j]];
                    A_copy[[max_row, j]] = temp;
                }
                det *= -1.0;
            }

            if A_copy[[i, i]].abs() < 1e-12 {
                return Ok(0.0);
            }

            det *= A_copy[[i, i]];

            for k in i + 1..n {
                let factor = A_copy[[k, i]] / A_copy[[i, i]];
                for j in i..n {
                    A_copy[[k, j]] -= factor * A_copy[[i, j]];
                }
            }
        }

        Ok(det)
    }

    fn matrix_inverse(&self, A: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n = A.dim().0;
        let mut aug = Array2::zeros((n, 2 * n));

        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = A[[i, j]];
                aug[[i, j + n]] = if i == j { 1.0 } else { 0.0 };
            }
        }

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..2 * n {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[i, i]].abs() < 1e-12 {
                return Err(SklearsError::NumericalError(
                    "Matrix is singular".to_string(),
                ));
            }

            let pivot = aug[[i, i]];
            for j in 0..2 * n {
                aug[[i, j]] /= pivot;
            }

            for k in 0..n {
                if k != i {
                    let factor = aug[[k, i]];
                    for j in 0..2 * n {
                        aug[[k, j]] -= factor * aug[[i, j]];
                    }
                }
            }
        }

        let mut inv = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                inv[[i, j]] = aug[[i, j + n]];
            }
        }

        Ok(inv)
    }

    fn compute_log_likelihood(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
    ) -> SklResult<f64> {
        let n_samples = X.nrows();
        let n_components = weights.len();
        let mut log_likelihood = 0.0;

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut sample_likelihood = 0.0;

            for k in 0..n_components {
                if weights[k] > 1e-10 {
                    let log_pdf =
                        self.multivariate_normal_log_pdf(&sample, &means.row(k), &covariances[k])?;
                    sample_likelihood += weights[k] * log_pdf.exp();
                }
            }

            if sample_likelihood > 1e-300 {
                log_likelihood += sample_likelihood.ln();
            }
        }

        Ok(log_likelihood)
    }
}

/// Dirichlet Process Gaussian mixture fitted with truncated variational EM:
/// the infinite mixture is truncated at `max_components`, and components
/// whose weights remain negligible are dropped after fitting.
#[derive(Debug, Clone)]
pub struct DirichletProcessGaussianMixture<S = Untrained> {
    state: S,
    /// Concentration parameter of the Dirichlet process.
    pub alpha: f64,
    /// Truncation level: maximum number of components considered.
    pub max_components: usize,
    /// Covariance parameterization of the component Gaussians.
    pub covariance_type: CovarianceType,
    /// Convergence tolerance on the lower bound.
    pub tol: f64,
    /// Maximum number of EM iterations.
    pub max_iter: usize,
    /// Ridge added to covariance diagonals for numerical stability.
    pub reg_covar: f64,
    /// Optional RNG seed for reproducible fits.
    pub random_state: Option<u64>,
    /// Number of initializations; stored but not yet used by `fit`.
    pub n_init: usize,
}

impl DirichletProcessGaussianMixture<Untrained> {
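    /// Creates an untrained model with the default configuration
    /// (`alpha = 1.0`, `max_components = 20`, full covariances).
    ///
    /// A minimal usage sketch; the crate path `sklears_mixture` and the
    /// `f64` feature type are illustrative assumptions:
    ///
    /// ```no_run
    /// use scirs2_core::ndarray::Array2;
    /// use sklears_core::traits::{Fit, Predict};
    /// use sklears_mixture::DirichletProcessGaussianMixture;
    ///
    /// let x = Array2::from_shape_vec(
    ///     (4, 2),
    ///     vec![0.0, 0.0, 0.1, 0.2, 4.0, 4.1, 3.9, 4.2],
    /// )
    /// .unwrap();
    /// let model = DirichletProcessGaussianMixture::new()
    ///     .max_components(10)
    ///     .random_state(7)
    ///     .fit(&x.view(), &())
    ///     .unwrap();
    /// let proba = model.predict_proba(&x.view()).unwrap();
    /// assert_eq!(proba.nrows(), 4);
    /// ```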
    pub fn new() -> Self {
        Self {
            state: Untrained,
            alpha: 1.0,
            max_components: 20,
            covariance_type: CovarianceType::Full,
            tol: 1e-3,
            max_iter: 100,
            reg_covar: 1e-6,
            random_state: None,
            n_init: 1,
        }
    }

    /// Sets the concentration parameter of the Dirichlet process.
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the truncation level (maximum number of components).
    pub fn max_components(mut self, max_components: usize) -> Self {
        self.max_components = max_components;
        self
    }

    /// Sets the covariance parameterization.
    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Sets the convergence tolerance on the lower bound.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Sets the maximum number of EM iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the ridge added to covariance diagonals.
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Sets the RNG seed for reproducible fits.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Sets the number of initializations.
    pub fn n_init(mut self, n_init: usize) -> Self {
        self.n_init = n_init;
        self
    }
}

impl Default for DirichletProcessGaussianMixture<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for DirichletProcessGaussianMixture<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for DirichletProcessGaussianMixture<Untrained> {
    type Fitted = DirichletProcessGaussianMixture<DirichletProcessGaussianMixtureTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "No samples provided".to_string(),
            ));
        }

        // Placeholder stick-breaking concentrations, reported as
        // `weight_concentration` on the fitted model.
        let stick_weights = Array1::ones(self.max_components);
        let mut weights = Array1::zeros(self.max_components);

        // Stick-breaking initialization; with equal proportions this yields
        // uniform weights of 1 / max_components.
        let mut remaining = 1.0;
        for k in 0..self.max_components {
            weights[k] = remaining / (self.max_components - k) as f64;
            remaining -= weights[k];
        }

        // Fall back to a fixed seed when none is given, for reproducibility.
        let mut rng = match self.random_state {
            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
            None => scirs2_core::random::rngs::StdRng::seed_from_u64(42),
        };

        // k-means++-style mean initialization: the first mean is a random
        // sample; each further mean is drawn with probability proportional
        // to the squared distance to the nearest existing mean.
        let mut means = Array2::zeros((self.max_components, n_features));
        means.row_mut(0).assign(&X.row(rng.gen_range(0..n_samples)));

        for k in 1..self.max_components {
            let mut distances = Array1::zeros(n_samples);
            for i in 0..n_samples {
                let mut min_dist = f64::INFINITY;
                for j in 0..k {
                    let dist = (&X.row(i) - &means.row(j)).mapv(|x| x * x).sum().sqrt();
                    if dist < min_dist {
                        min_dist = dist;
                    }
                }
                distances[i] = min_dist * min_dist;
            }

            let total_dist: f64 = distances.sum();
            let target = rng.gen::<f64>() * total_dist;
            let mut cumulative = 0.0;

            for i in 0..n_samples {
                cumulative += distances[i];
                if cumulative >= target {
                    means.row_mut(k).assign(&X.row(i));
                    break;
                }
            }
        }

        // Every component starts from the global sample covariance.
        let sample_cov = self.compute_sample_covariance(&X)?;
        let mut covariances = Vec::new();
        for _ in 0..self.max_components {
            covariances.push(sample_cov.clone());
        }

        let mut prev_lower_bound = f64::NEG_INFINITY;
        let mut converged = false;
        let mut n_iter = 0;

        for iteration in 0..self.max_iter {
            n_iter = iteration + 1;

            // E-step.
            let responsibilities =
                self.compute_responsibilities(&X, &weights, &means, &covariances)?;

            // M-step.
            weights = self.update_weights(&responsibilities)?;
            means = self.update_means(&X, &responsibilities)?;
            covariances = self.update_covariances(&X, &responsibilities, &means)?;

            let lower_bound =
                self.compute_lower_bound(&X, &responsibilities, &weights, &means, &covariances)?;

            if (lower_bound - prev_lower_bound).abs() < self.tol {
                converged = true;
                break;
            }
            prev_lower_bound = lower_bound;
        }

        // Truncate to the last component whose weight is non-negligible.
        let mut n_components = 0;
        for k in 0..self.max_components {
            if weights[k] > 1e-3 {
                n_components = k + 1;
            }
        }

        Ok(DirichletProcessGaussianMixture {
            state: DirichletProcessGaussianMixtureTrained {
                weights: weights.slice(s![..n_components]).to_owned(),
                means: means.slice(s![..n_components, ..]).to_owned(),
                covariances: covariances.into_iter().take(n_components).collect(),
                weight_concentration: stick_weights.slice(s![..n_components]).to_owned(),
                lower_bound: prev_lower_bound,
                n_iter,
                converged,
                n_components,
                n_features,
                covariance_type: self.covariance_type.clone(),
                alpha: self.alpha,
                reg_covar: self.reg_covar,
            },
            alpha: self.alpha,
            max_components: self.max_components,
            covariance_type: self.covariance_type,
            tol: self.tol,
            max_iter: self.max_iter,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            n_init: self.n_init,
        })
    }
}

impl DirichletProcessGaussianMixture<Untrained> {
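    /// Unbiased sample covariance of `X` with `reg_covar` added to the
    /// diagonal; used to initialize every component's covariance.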
    fn compute_sample_covariance(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();
        let mean = X.mean_axis(Axis(0)).unwrap();
        let mut cov = Array2::zeros((n_features, n_features));

        for i in 0..n_samples {
            let diff = &X.row(i) - &mean;
            for j in 0..n_features {
                for k in 0..n_features {
                    cov[[j, k]] += diff[j] * diff[k];
                }
            }
        }
        // Guard against division by zero for a single-sample input.
        if n_samples > 1 {
            cov /= (n_samples - 1) as f64;
        }

        for i in 0..n_features {
            cov[[i, i]] += self.reg_covar;
        }

        Ok(cov)
    }

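    /// E-step: responsibilities `r[i, k] ∝ w_k * N(x_i; mu_k, cov_k)`,
    /// computed and normalized in log space for numerical stability.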
    fn compute_responsibilities(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
    ) -> SklResult<Array2<f64>> {
        let (n_samples, _) = X.dim();
        let n_components = weights.len();
        let mut responsibilities = Array2::zeros((n_samples, n_components));

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_probs = Array1::zeros(n_components);

            for k in 0..n_components {
                let log_weight = weights[k].ln();
                let log_pdf =
                    self.multivariate_normal_log_pdf(&sample, &means.row(k), &covariances[k])?;
                log_probs[k] = log_weight + log_pdf;
            }

            // Normalize in log space (log-sum-exp) before exponentiating.
            let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let log_sum_exp = max_log_prob
                + log_probs
                    .iter()
                    .map(|&lp| (lp - max_log_prob).exp())
                    .sum::<f64>()
                    .ln();

            for k in 0..n_components {
                responsibilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
            }
        }

        Ok(responsibilities)
    }

    /// M-step for the weights: each component's weight is its mean
    /// responsibility over the samples.
    fn update_weights(&self, responsibilities: &Array2<f64>) -> SklResult<Array1<f64>> {
        let n_components = responsibilities.dim().1;
        let mut weights = Array1::zeros(n_components);

        for k in 0..n_components {
            weights[k] = responsibilities.column(k).sum() / responsibilities.dim().0 as f64;
        }

        Ok(weights)
    }

    /// M-step for the means: responsibility-weighted averages of the samples.
    fn update_means(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
    ) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();
        let n_components = responsibilities.dim().1;
        let mut means = Array2::zeros((n_components, n_features));

        for k in 0..n_components {
            let weight_sum = responsibilities.column(k).sum();
            if weight_sum > 1e-10 {
                for j in 0..n_features {
                    let mut weighted_sum = 0.0;
                    for i in 0..n_samples {
                        weighted_sum += responsibilities[[i, k]] * X[[i, j]];
                    }
                    means[[k, j]] = weighted_sum / weight_sum;
                }
            }
        }

        Ok(means)
    }

    /// M-step for the covariances: responsibility-weighted scatter matrices,
    /// falling back to the identity for effectively empty components, with
    /// `reg_covar` added to every diagonal.
    fn update_covariances(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
        means: &Array2<f64>,
    ) -> SklResult<Vec<Array2<f64>>> {
        let (n_samples, n_features) = X.dim();
        let n_components = responsibilities.dim().1;
        let mut covariances = Vec::new();

        for k in 0..n_components {
            let weight_sum = responsibilities.column(k).sum();
            let mut cov = Array2::zeros((n_features, n_features));

            if weight_sum > 1e-10 {
                for i in 0..n_samples {
                    let diff = &X.row(i) - &means.row(k);
                    let weight = responsibilities[[i, k]];
                    for j in 0..n_features {
                        for l in 0..n_features {
                            cov[[j, l]] += weight * diff[j] * diff[l];
                        }
                    }
                }
                cov /= weight_sum;
            } else {
                for j in 0..n_features {
                    cov[[j, j]] = 1.0;
                }
            }

            for j in 0..n_features {
                cov[[j, j]] += self.reg_covar;
            }

            covariances.push(cov);
        }

        Ok(covariances)
    }

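    /// Variational lower bound used for convergence checks: the expected
    /// complete-data log-likelihood `sum_{i,k} r[i,k] * (ln w_k + ln N(x_i))`
    /// plus the entropy of the responsibilities
    /// `-sum_{i,k} r[i,k] * ln r[i,k]`.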
    fn compute_lower_bound(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
    ) -> SklResult<f64> {
        let (n_samples, _) = X.dim();
        let n_components = weights.len();
        let mut lower_bound = 0.0;

        // Expected complete-data log-likelihood.
        for i in 0..n_samples {
            let sample = X.row(i);
            for k in 0..n_components {
                if responsibilities[[i, k]] > 1e-10 {
                    let log_pdf =
                        self.multivariate_normal_log_pdf(&sample, &means.row(k), &covariances[k])?;
                    lower_bound += responsibilities[[i, k]] * (weights[k].ln() + log_pdf);
                }
            }
        }

        // Entropy of the responsibilities.
        for i in 0..n_samples {
            for k in 0..n_components {
                if responsibilities[[i, k]] > 1e-10 {
                    lower_bound -= responsibilities[[i, k]] * responsibilities[[i, k]].ln();
                }
            }
        }

        Ok(lower_bound)
    }

    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff = x - mean;

        let det = self.matrix_determinant(cov)?;
        if det <= 0.0 {
            return Ok(f64::NEG_INFINITY);
        }

        let inv_cov = self.matrix_inverse(cov)?;
        let mut quad_form = 0.0;
        for i in 0..diff.len() {
            for j in 0..diff.len() {
                quad_form += diff[i] * inv_cov[[i, j]] * diff[j];
            }
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + det.ln() + quad_form))
    }

    fn matrix_determinant(&self, A: &Array2<f64>) -> SklResult<f64> {
        let n = A.dim().0;
        if n == 1 {
            return Ok(A[[0, 0]]);
        }
        if n == 2 {
            return Ok(A[[0, 0]] * A[[1, 1]] - A[[0, 1]] * A[[1, 0]]);
        }

        let mut det = 1.0;
        let mut A_copy = A.clone();

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if A_copy[[k, i]].abs() > A_copy[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..n {
                    let temp = A_copy[[i, j]];
                    A_copy[[i, j]] = A_copy[[max_row, j]];
                    A_copy[[max_row, j]] = temp;
                }
                det *= -1.0;
            }

            if A_copy[[i, i]].abs() < 1e-12 {
                return Ok(0.0);
            }

            det *= A_copy[[i, i]];

            for k in i + 1..n {
                let factor = A_copy[[k, i]] / A_copy[[i, i]];
                for j in i..n {
                    A_copy[[k, j]] -= factor * A_copy[[i, j]];
                }
            }
        }

        Ok(det)
    }

    fn matrix_inverse(&self, A: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n = A.dim().0;
        let mut aug = Array2::zeros((n, 2 * n));

        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = A[[i, j]];
                aug[[i, j + n]] = if i == j { 1.0 } else { 0.0 };
            }
        }

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..2 * n {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[i, i]].abs() < 1e-12 {
                return Err(SklearsError::NumericalError(
                    "Matrix is singular".to_string(),
                ));
            }

            let pivot = aug[[i, i]];
            for j in 0..2 * n {
                aug[[i, j]] /= pivot;
            }

            for k in 0..n {
                if k != i {
                    let factor = aug[[k, i]];
                    for j in 0..2 * n {
                        aug[[k, j]] -= factor * aug[[i, j]];
                    }
                }
            }
        }

        let mut inv = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                inv[[i, j]] = aug[[i, j + n]];
            }
        }

        Ok(inv)
    }
}

/// Fitted state of a [`DirichletProcessGaussianMixture`].
#[derive(Debug, Clone)]
pub struct DirichletProcessGaussianMixtureTrained {
    /// Mixture weights of the retained components.
    pub weights: Array1<f64>,
    /// Component means, one row per component.
    pub means: Array2<f64>,
    /// Component covariance matrices.
    pub covariances: Vec<Array2<f64>>,
    /// Stick-breaking concentration parameters of the retained components.
    pub weight_concentration: Array1<f64>,
    /// Final variational lower bound.
    pub lower_bound: f64,
    /// Number of EM iterations performed.
    pub n_iter: usize,
    /// Whether the lower bound converged within `tol`.
    pub converged: bool,
    /// Number of retained components.
    pub n_components: usize,
    /// Number of features seen during fitting.
    pub n_features: usize,
    /// Covariance parameterization used during fitting.
    pub covariance_type: CovarianceType,
    /// Concentration parameter used during fitting.
    pub alpha: f64,
    /// Ridge used for covariance regularization.
    pub reg_covar: f64,
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for DirichletProcessGaussianMixture<DirichletProcessGaussianMixtureTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut max_log_prob = f64::NEG_INFINITY;
            let mut best_component = 0;

            // Hard assignment: argmax over ln(w_k) + ln N(x_i; mu_k, cov_k).
            for k in 0..self.state.n_components {
                let log_weight = self.state.weights[k].ln();
                let log_pdf = self.multivariate_normal_log_pdf(
                    &sample,
                    &self.state.means.row(k),
                    &self.state.covariances[k],
                )?;
                let log_prob = log_weight + log_pdf;

                if log_prob > max_log_prob {
                    max_log_prob = log_prob;
                    best_component = k;
                }
            }

            predictions[i] = best_component as i32;
        }

        Ok(predictions)
    }
}

impl DirichletProcessGaussianMixture<DirichletProcessGaussianMixtureTrained> {
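    /// Posterior component probabilities for each sample, normalized in
    /// log space with the log-sum-exp trick.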
    #[allow(non_snake_case)]
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut probabilities = Array2::zeros((n_samples, self.state.n_components));

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_probs = Array1::zeros(self.state.n_components);

            for k in 0..self.state.n_components {
                let log_weight = self.state.weights[k].ln();
                let log_pdf = self.multivariate_normal_log_pdf(
                    &sample,
                    &self.state.means.row(k),
                    &self.state.covariances[k],
                )?;
                log_probs[k] = log_weight + log_pdf;
            }

            // Normalize in log space (log-sum-exp) before exponentiating.
            let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let log_sum_exp = max_log_prob
                + log_probs
                    .iter()
                    .map(|&lp| (lp - max_log_prob).exp())
                    .sum::<f64>()
                    .ln();

            for k in 0..self.state.n_components {
                probabilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
            }
        }

        Ok(probabilities)
    }

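    /// Returns the lower bound reached during fitting; note that the input
    /// is currently ignored rather than re-scored.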
    pub fn score(&self, _X: &ArrayView2<'_, Float>) -> SklResult<f64> {
        Ok(self.state.lower_bound)
    }

    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff = x - mean;

        let det = self.matrix_determinant(cov)?;
        if det <= 0.0 {
            return Ok(f64::NEG_INFINITY);
        }

        let inv_cov = self.matrix_inverse(cov)?;
        let mut quad_form = 0.0;
        for i in 0..diff.len() {
            for j in 0..diff.len() {
                quad_form += diff[i] * inv_cov[[i, j]] * diff[j];
            }
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + det.ln() + quad_form))
    }

    fn matrix_determinant(&self, A: &Array2<f64>) -> SklResult<f64> {
        let n = A.dim().0;
        if n == 1 {
            return Ok(A[[0, 0]]);
        }
        if n == 2 {
            return Ok(A[[0, 0]] * A[[1, 1]] - A[[0, 1]] * A[[1, 0]]);
        }

        let mut det = 1.0;
        let mut A_copy = A.clone();

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if A_copy[[k, i]].abs() > A_copy[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..n {
                    let temp = A_copy[[i, j]];
                    A_copy[[i, j]] = A_copy[[max_row, j]];
                    A_copy[[max_row, j]] = temp;
                }
                det *= -1.0;
            }

            if A_copy[[i, i]].abs() < 1e-12 {
                return Ok(0.0);
            }

            det *= A_copy[[i, i]];

            for k in i + 1..n {
                let factor = A_copy[[k, i]] / A_copy[[i, i]];
                for j in i..n {
                    A_copy[[k, j]] -= factor * A_copy[[i, j]];
                }
            }
        }

        Ok(det)
    }

    fn matrix_inverse(&self, A: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n = A.dim().0;
        let mut aug = Array2::zeros((n, 2 * n));

        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = A[[i, j]];
                aug[[i, j + n]] = if i == j { 1.0 } else { 0.0 };
            }
        }

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..2 * n {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[i, i]].abs() < 1e-12 {
                return Err(SklearsError::NumericalError(
                    "Matrix is singular".to_string(),
                ));
            }

            let pivot = aug[[i, i]];
            for j in 0..2 * n {
                aug[[i, j]] /= pivot;
            }

            for k in 0..n {
                if k != i {
                    let factor = aug[[k, i]];
                    for j in 0..2 * n {
                        aug[[k, j]] -= factor * aug[[i, j]];
                    }
                }
            }
        }

        let mut inv = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                inv[[i, j]] = aug[[i, j + n]];
            }
        }

        Ok(inv)
    }
}
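
// A minimal smoke-test sketch for this module. The expectations below
// (output shapes, row normalization) and the assumption that `Float` is
// `f64` are illustrative, not assertions carried over from an existing
// test suite.
#[cfg(test)]
mod tests {
    use super::*;

    fn toy_data() -> Array2<f64> {
        // Two well-separated blobs of two points each.
        Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.2, 5.0, 5.0, 5.1, 4.8]).unwrap()
    }

    #[test]
    fn log_sum_exp_matches_direct_evaluation() {
        let direct = (1.0_f64.exp() + 2.0_f64.exp()).ln();
        assert!((log_sum_exp(1.0, 2.0) - direct).abs() < 1e-12);
    }

    #[test]
    fn crp_predicts_one_label_per_sample() {
        let x = toy_data();
        let model = ChineseRestaurantProcess::new()
            .random_state(42)
            .fit(&x.view(), &())
            .unwrap();
        let labels = model.predict(&x.view()).unwrap();
        assert_eq!(labels.len(), x.nrows());
    }

    #[test]
    fn dpgm_responsibilities_are_normalized() {
        let x = toy_data();
        let model = DirichletProcessGaussianMixture::new()
            .max_components(5)
            .random_state(7)
            .fit(&x.view(), &())
            .unwrap();
        let proba = model.predict_proba(&x.view()).unwrap();
        for i in 0..proba.nrows() {
            let row_sum: f64 = proba.row(i).sum();
            assert!((row_sum - 1.0).abs() < 1e-6);
        }
    }
}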