sklears_mixture/
nonparametric.rs

//! Nonparametric Mixture Models
//!
//! This module provides nonparametric Bayesian mixture models that automatically
//! determine the number of components from the data. It currently implements the
//! Chinese Restaurant Process mixture and the Dirichlet Process Gaussian Mixture
//! Model; Pitman-Yor Process and Hierarchical Dirichlet Process variants are
//! planned (see the note at the end of this file).

use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::{Rng, SeedableRng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::f64::consts::PI;

use crate::common::CovarianceType;

/// Numerically stable log-sum-exp of two values, `ln(exp(a) + exp(b))`.
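///
/// Factoring out the maximum keeps the computation finite for large inputs:
///
/// ```text
/// // Illustrative values: for a = ln 2 and b = ln 3,
/// // log_sum_exp(a, b) = ln(2 + 3) ≈ 1.609.
/// ```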
#[allow(dead_code)] // not called yet within this module; kept as a shared numeric helper
fn log_sum_exp(a: f64, b: f64) -> f64 {
    let max_val = a.max(b);
    if max_val.is_finite() {
        max_val + ((a - max_val).exp() + (b - max_val).exp()).ln()
    } else {
        max_val
    }
}

/// Chinese Restaurant Process Mixture Model
///
/// A nonparametric Bayesian mixture model that automatically determines the number of components
/// using the Chinese Restaurant Process metaphor. Customers (data points) sit at tables (components)
/// with probability proportional to the number of existing customers at each table, or start a new
/// table with probability proportional to the concentration parameter α.
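///
/// Concretely, if n customers are already seated and table k currently holds n_k of
/// them, customer n + 1 joins table k with probability n_k / (n + α) and opens a new
/// table with probability α / (n + α).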
///
/// # Examples
///
/// ```
/// use sklears_mixture::{ChineseRestaurantProcess, CovarianceType};
/// use sklears_core::traits::{Predict, Fit};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [10.0, 10.0], [11.0, 11.0], [12.0, 12.0]];
///
/// let model = ChineseRestaurantProcess::new()
///     .alpha(1.0)
///     .max_components(10)
///     .covariance_type(CovarianceType::Diagonal);
/// let fitted = model.fit(&X.view(), &()).unwrap();
/// let labels = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct ChineseRestaurantProcess<S = Untrained> {
    state: S,
    /// Concentration parameter (higher values favor more clusters)
    alpha: f64,
    /// Maximum number of components allowed
    max_components: usize,
    /// Covariance type for the Gaussian base distribution
    covariance_type: CovarianceType,
    /// Convergence tolerance
    tol: f64,
    /// Maximum number of iterations
    max_iter: usize,
    /// Random state for reproducibility
    random_state: Option<u64>,
    /// Regularization parameter for covariance matrices
    reg_covar: f64,
}

impl ChineseRestaurantProcess<Untrained> {
    /// Create a new Chinese Restaurant Process
    pub fn new() -> Self {
        Self {
            state: Untrained,
            alpha: 1.0,
            max_components: 20,
            covariance_type: CovarianceType::Full,
            tol: 1e-3,
            max_iter: 100,
            random_state: None,
            reg_covar: 1e-6,
        }
    }

    /// Set the concentration parameter
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the maximum number of components
    pub fn max_components(mut self, max_components: usize) -> Self {
        self.max_components = max_components;
        self
    }

    /// Set the covariance type
    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Set the convergence tolerance
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Set the regularization parameter
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Initialize table assignments by drawing a seating arrangement from the CRP prior
    fn initialize_tables(
        &self,
        X: &Array2<f64>,
        rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> SklResult<(Array1<usize>, Array1<usize>, usize)> {
        let n_samples = X.nrows();
        let mut table_assignments = Array1::zeros(n_samples);
        let mut table_counts: Array1<usize> = Array1::zeros(self.max_components);
        let mut n_tables = 1;

        // Assign first customer to first table
        table_assignments[0] = 0;
        table_counts[0] = 1;

        // Process remaining customers
        for i in 1..n_samples {
            let total_customers = i;
            let mut probabilities = Array1::zeros(n_tables + 1);

            // Probability of sitting at existing tables
            for k in 0..n_tables {
                probabilities[k] = table_counts[k] as f64 / (total_customers as f64 + self.alpha);
            }

            // Probability of creating new table
            if n_tables < self.max_components {
                probabilities[n_tables] = self.alpha / (total_customers as f64 + self.alpha);
            }

            // Sample table assignment
            let cumsum: f64 = probabilities.iter().sum();
            let mut cumulative = 0.0;
            let target = rng.gen::<f64>() * cumsum;

            let mut chosen_table = 0;
            for k in 0..=n_tables {
                cumulative += probabilities[k];
                if target <= cumulative {
                    chosen_table = k;
                    break;
                }
            }

            // Update assignments and counts
            table_assignments[i] = chosen_table;
            table_counts[chosen_table] += 1;

            // If new table was chosen, increment table count
            if chosen_table == n_tables && n_tables < self.max_components {
                n_tables += 1;
            }
        }

        Ok((table_assignments, table_counts, n_tables))
    }

    /// Compute component parameters from assignments
    fn compute_parameters(
        &self,
        X: &Array2<f64>,
        table_assignments: &Array1<usize>,
        n_tables: usize,
    ) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
        let (n_samples, n_features) = X.dim();

        // Compute weights (normalized table sizes)
        let mut weights = Array1::zeros(n_tables);
        for &assignment in table_assignments.iter() {
            if assignment < n_tables {
                weights[assignment] += 1.0;
            }
        }
        weights /= n_samples as f64;

        // Compute means
        let mut means = Array2::zeros((n_tables, n_features));
        let mut counts: Array1<f64> = Array1::zeros(n_tables);

        for (i, &assignment) in table_assignments.iter().enumerate() {
            if assignment < n_tables {
                for j in 0..n_features {
                    means[[assignment, j]] += X[[i, j]];
                }
                counts[assignment] += 1.0;
            }
        }

        // Normalize means
        for k in 0..n_tables {
            if counts[k] > 0.0 {
                for j in 0..n_features {
                    means[[k, j]] /= counts[k];
                }
            }
        }

        // Compute covariances
        let mut covariances = Vec::new();
        for k in 0..n_tables {
            let mut cov = Array2::zeros((n_features, n_features));
            let mut count = 0.0;

            for (i, &assignment) in table_assignments.iter().enumerate() {
                if assignment == k {
                    let diff = &X.row(i) - &means.row(k);
                    for j in 0..n_features {
                        for l in 0..n_features {
                            cov[[j, l]] += diff[j] * diff[l];
                        }
                    }
                    count += 1.0;
                }
            }

            if count > 1.0 {
                cov /= count - 1.0;
            } else {
                // Use identity for single-point clusters
                for j in 0..n_features {
                    cov[[j, j]] = 1.0;
                }
            }

            // Apply covariance type constraints and regularization
            cov = self.regularize_covariance(cov)?;
            covariances.push(cov);
        }

        Ok((weights, means, covariances))
    }

    /// Regularize covariance matrix based on covariance type
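    ///
    /// `Full` and `Tied` keep the full matrix and add `reg_covar` to the diagonal;
    /// `Diagonal` zeroes the off-diagonal entries; `Spherical` replaces the matrix
    /// with (mean diagonal variance + `reg_covar`) times the identity.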
    fn regularize_covariance(&self, mut cov: Array2<f64>) -> SklResult<Array2<f64>> {
        let n_features = cov.dim().0;

        match self.covariance_type {
            CovarianceType::Full => {
                // Add regularization to diagonal
                for i in 0..n_features {
                    cov[[i, i]] += self.reg_covar;
                }
            }
            CovarianceType::Diagonal => {
                // Keep only diagonal elements
                for i in 0..n_features {
                    for j in 0..n_features {
                        if i != j {
                            cov[[i, j]] = 0.0;
                        } else {
                            cov[[i, i]] += self.reg_covar;
                        }
                    }
                }
            }
            CovarianceType::Tied => {
                // Use full covariance (tied across components handled at higher level)
                for i in 0..n_features {
                    cov[[i, i]] += self.reg_covar;
                }
            }
            CovarianceType::Spherical => {
                // Use scalar variance
                let trace = (0..n_features).map(|i| cov[[i, i]]).sum::<f64>() / n_features as f64;
                cov.fill(0.0);
                for i in 0..n_features {
                    cov[[i, i]] = trace + self.reg_covar;
                }
            }
        }

        Ok(cov)
    }

    /// Compute log likelihood
    fn compute_log_likelihood(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
    ) -> SklResult<f64> {
        let n_samples = X.nrows();
        let n_components = weights.len();
        let mut log_likelihood = 0.0;

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut sample_likelihood = 0.0;

            for k in 0..n_components {
                if weights[k] > 1e-10 {
                    let log_pdf =
                        self.multivariate_normal_log_pdf(&sample, &means.row(k), &covariances[k])?;
                    sample_likelihood += weights[k] * log_pdf.exp();
                }
            }

            // Guard against ln(0): samples with vanishing likelihood are skipped
            if sample_likelihood > 1e-300 {
                log_likelihood += sample_likelihood.ln();
            }
        }

        Ok(log_likelihood)
    }

    /// Compute multivariate normal log PDF
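    ///
    /// Evaluates ln N(x; μ, Σ) = −0.5 · (d·ln(2π) + ln|Σ| + (x − μ)ᵀ Σ⁻¹ (x − μ)),
    /// returning negative infinity when Σ is not positive definite.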
    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff = x - mean;

        // Simple determinant and inverse computation for small matrices
        let det = self.matrix_determinant(cov)?;
        if det <= 0.0 {
            return Ok(f64::NEG_INFINITY);
        }

        let inv_cov = self.matrix_inverse(cov)?;
        let mut quad_form = 0.0;
        for i in 0..diff.len() {
            for j in 0..diff.len() {
                quad_form += diff[i] * inv_cov[[i, j]] * diff[j];
            }
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + det.ln() + quad_form))
    }

    /// Matrix determinant via Gaussian elimination with partial pivoting
    fn matrix_determinant(&self, A: &Array2<f64>) -> SklResult<f64> {
        let n = A.dim().0;
        if n == 1 {
            return Ok(A[[0, 0]]);
        }
        if n == 2 {
            return Ok(A[[0, 0]] * A[[1, 1]] - A[[0, 1]] * A[[1, 0]]);
        }

        // For larger matrices, reduce to upper-triangular form and multiply the pivots
        let mut det = 1.0;
        let mut A_copy = A.clone();

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if A_copy[[k, i]].abs() > A_copy[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..n {
                    let temp = A_copy[[i, j]];
                    A_copy[[i, j]] = A_copy[[max_row, j]];
                    A_copy[[max_row, j]] = temp;
                }
                det *= -1.0;
            }

            if A_copy[[i, i]].abs() < 1e-12 {
                return Ok(0.0);
            }

            det *= A_copy[[i, i]];

            for k in i + 1..n {
                let factor = A_copy[[k, i]] / A_copy[[i, i]];
                for j in i..n {
                    A_copy[[k, j]] -= factor * A_copy[[i, j]];
                }
            }
        }

        Ok(det)
    }

    /// Simple matrix inverse using Gauss-Jordan elimination
    fn matrix_inverse(&self, A: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n = A.dim().0;
        let mut aug = Array2::zeros((n, 2 * n));

        // Create augmented matrix [A | I]
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = A[[i, j]];
                aug[[i, j + n]] = if i == j { 1.0 } else { 0.0 };
            }
        }

        // Gauss-Jordan elimination
        for i in 0..n {
            // Find pivot
            let mut max_row = i;
            for k in i + 1..n {
                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            // Swap rows
            if max_row != i {
                for j in 0..2 * n {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            // Check for singular matrix
            if aug[[i, i]].abs() < 1e-12 {
                return Err(SklearsError::NumericalError(
                    "Matrix is singular".to_string(),
                ));
            }

            // Scale pivot row
            let pivot = aug[[i, i]];
            for j in 0..2 * n {
                aug[[i, j]] /= pivot;
            }

            // Eliminate column
            for k in 0..n {
                if k != i {
                    let factor = aug[[k, i]];
                    for j in 0..2 * n {
                        aug[[k, j]] -= factor * aug[[i, j]];
                    }
                }
            }
        }

        // Extract inverse matrix
        let mut inv = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                inv[[i, j]] = aug[[i, j + n]];
            }
        }

        Ok(inv)
    }
}

impl Default for ChineseRestaurantProcess<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for ChineseRestaurantProcess<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for ChineseRestaurantProcess<Untrained> {
    type Fitted = ChineseRestaurantProcess<ChineseRestaurantProcessTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "No samples provided".to_string(),
            ));
        }

        let mut rng = match self.random_state {
            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
            None => scirs2_core::random::rngs::StdRng::seed_from_u64(42),
        };

        // Initialize table assignments
        let (mut table_assignments, mut table_counts, mut n_tables) =
            self.initialize_tables(&X, &mut rng)?;

        let mut prev_log_likelihood = f64::NEG_INFINITY;
        let mut converged = false;
        let mut n_iter = 0;

        // Gibbs sampling iterations
        for iteration in 0..self.max_iter {
            n_iter = iteration + 1;

            // Update table assignments for each customer
            for i in 0..n_samples {
                // Remove customer from current table
                let current_table = table_assignments[i];
                table_counts[current_table] -= 1;

                // If the last table becomes empty, shrink the active table count;
                // empty interior tables keep a zero seating probability below
                if table_counts[current_table] == 0 && current_table == n_tables - 1 {
                    n_tables -= 1;
                }

                // Compute CRP seating probabilities for each table
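                // Note: this sweep resamples seats from the seating prior only; a full
                // collapsed Gibbs step would also weight each table by the predictive
                // likelihood of X[i] under that table's parameters.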
                let mut probabilities = Array1::zeros(n_tables + 1);
                let remaining_customers = n_samples - 1;

                // Existing tables
                for k in 0..n_tables {
                    if table_counts[k] > 0 {
                        probabilities[k] =
                            table_counts[k] as f64 / (remaining_customers as f64 + self.alpha);
                    }
                }

                // New table
                if n_tables < self.max_components {
                    probabilities[n_tables] =
                        self.alpha / (remaining_customers as f64 + self.alpha);
                }

                // Sample new assignment
                let cumsum: f64 = probabilities.iter().sum();
                let mut cumulative = 0.0;
                let target = rng.gen::<f64>() * cumsum;

                let mut new_table = 0;
                for k in 0..=n_tables {
                    cumulative += probabilities[k];
                    if target <= cumulative {
                        new_table = k;
                        break;
                    }
                }

                // Update assignments
                table_assignments[i] = new_table;
                table_counts[new_table] += 1;

                // If new table was chosen, increment table count
                if new_table == n_tables && n_tables < self.max_components {
                    n_tables += 1;
                }
            }

            // Compute parameters and log-likelihood every fifth iteration
            if iteration % 5 == 0 {
                let (weights, means, covariances) =
                    self.compute_parameters(&X, &table_assignments, n_tables)?;
                let log_likelihood =
                    self.compute_log_likelihood(&X, &weights, &means, &covariances)?;

                if (log_likelihood - prev_log_likelihood).abs() < self.tol {
                    converged = true;
                    break;
                }
                prev_log_likelihood = log_likelihood;
            }
        }

        // Final parameter computation
        let (weights, means, covariances) =
            self.compute_parameters(&X, &table_assignments, n_tables)?;
        let log_likelihood = self.compute_log_likelihood(&X, &weights, &means, &covariances)?;

        Ok(ChineseRestaurantProcess {
            state: ChineseRestaurantProcessTrained {
                n_components: n_tables,
                weights,
                means,
                covariances,
                covariance_type: self.covariance_type.clone(),
                n_features,
                alpha: self.alpha,
                table_assignments,
                table_counts: table_counts.slice(s![..n_tables]).to_owned(),
                log_likelihood,
                n_iter,
                converged,
                reg_covar: self.reg_covar,
            },
            alpha: self.alpha,
            max_components: self.max_components,
            covariance_type: self.covariance_type,
            tol: self.tol,
            max_iter: self.max_iter,
            random_state: self.random_state,
            reg_covar: self.reg_covar,
        })
    }
}

/// Trained Chinese Restaurant Process Mixture Model
#[derive(Debug, Clone)]
pub struct ChineseRestaurantProcessTrained {
    /// Number of active components
    pub n_components: usize,
    /// Mixture weights (table sizes normalized)
    pub weights: Array1<f64>,
    /// Component means
    pub means: Array2<f64>,
    /// Component covariances
    pub covariances: Vec<Array2<f64>>,
    /// Covariance type
    pub covariance_type: CovarianceType,
    /// Number of features
    pub n_features: usize,
    /// Concentration parameter
    pub alpha: f64,
    /// Table assignments for training data
    pub table_assignments: Array1<usize>,
    /// Number of customers at each table
    pub table_counts: Array1<usize>,
    /// Log-likelihood of the model
    pub log_likelihood: f64,
    /// Number of iterations until convergence
    pub n_iter: usize,
    /// Whether the algorithm converged
    pub converged: bool,
    /// Regularization parameter
    pub reg_covar: f64,
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for ChineseRestaurantProcess<ChineseRestaurantProcessTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut max_log_prob = f64::NEG_INFINITY;
            let mut best_component = 0;

            for k in 0..self.state.n_components {
                let log_weight = self.state.weights[k].ln();
                let log_pdf = self.multivariate_normal_log_pdf(
                    &sample,
                    &self.state.means.row(k),
                    &self.state.covariances[k],
                )?;
                let log_prob = log_weight + log_pdf;

                if log_prob > max_log_prob {
                    max_log_prob = log_prob;
                    best_component = k;
                }
            }

            predictions[i] = best_component as i32;
        }

        Ok(predictions)
    }
}

impl ChineseRestaurantProcess<ChineseRestaurantProcessTrained> {
    /// Predict class probabilities
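    ///
    /// Returns an `(n_samples, n_components)` matrix of posterior responsibilities
    /// whose rows sum to 1.
    ///
    /// ```text
    /// // Sketch, reusing `fitted` and `X` from the type-level example above:
    /// let proba = fitted.predict_proba(&X.view()).unwrap();
    /// assert_eq!(proba.dim().0, X.nrows());
    /// ```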
    #[allow(non_snake_case)]
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut probabilities = Array2::zeros((n_samples, self.state.n_components));

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_probs = Array1::zeros(self.state.n_components);

            for k in 0..self.state.n_components {
                let log_weight = self.state.weights[k].ln();
                let log_pdf = self.multivariate_normal_log_pdf(
                    &sample,
                    &self.state.means.row(k),
                    &self.state.covariances[k],
                )?;
                log_probs[k] = log_weight + log_pdf;
            }

            // Numerically stable normalization
            let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let log_sum_exp = max_log_prob
                + log_probs
                    .iter()
                    .map(|&lp| (lp - max_log_prob).exp())
                    .sum::<f64>()
                    .ln();

            for k in 0..self.state.n_components {
                probabilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
            }
        }

        Ok(probabilities)
    }

    /// Score samples using log-likelihood
    #[allow(non_snake_case)]
    pub fn score(&self, X: &ArrayView2<'_, Float>) -> SklResult<f64> {
        let X = X.to_owned();
        self.compute_log_likelihood(
            &X,
            &self.state.weights,
            &self.state.means,
            &self.state.covariances,
        )
    }

    /// Helper method for log PDF computation
    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff = x - mean;

        let det = self.matrix_determinant(cov)?;
        if det <= 0.0 {
            return Ok(f64::NEG_INFINITY);
        }

        let inv_cov = self.matrix_inverse(cov)?;
        let mut quad_form = 0.0;
        for i in 0..diff.len() {
            for j in 0..diff.len() {
                quad_form += diff[i] * inv_cov[[i, j]] * diff[j];
            }
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + det.ln() + quad_form))
    }

    /// Helper methods (duplicated for the trained state)
    fn matrix_determinant(&self, A: &Array2<f64>) -> SklResult<f64> {
        let n = A.dim().0;
        if n == 1 {
            return Ok(A[[0, 0]]);
        }
        if n == 2 {
            return Ok(A[[0, 0]] * A[[1, 1]] - A[[0, 1]] * A[[1, 0]]);
        }

        let mut det = 1.0;
        let mut A_copy = A.clone();

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if A_copy[[k, i]].abs() > A_copy[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..n {
                    let temp = A_copy[[i, j]];
                    A_copy[[i, j]] = A_copy[[max_row, j]];
                    A_copy[[max_row, j]] = temp;
                }
                det *= -1.0;
            }

            if A_copy[[i, i]].abs() < 1e-12 {
                return Ok(0.0);
            }

            det *= A_copy[[i, i]];

            for k in i + 1..n {
                let factor = A_copy[[k, i]] / A_copy[[i, i]];
                for j in i..n {
                    A_copy[[k, j]] -= factor * A_copy[[i, j]];
                }
            }
        }

        Ok(det)
    }

    fn matrix_inverse(&self, A: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n = A.dim().0;
        let mut aug = Array2::zeros((n, 2 * n));

        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = A[[i, j]];
                aug[[i, j + n]] = if i == j { 1.0 } else { 0.0 };
            }
        }

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..2 * n {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[i, i]].abs() < 1e-12 {
                return Err(SklearsError::NumericalError(
                    "Matrix is singular".to_string(),
                ));
            }

            let pivot = aug[[i, i]];
            for j in 0..2 * n {
                aug[[i, j]] /= pivot;
            }

            for k in 0..n {
                if k != i {
                    let factor = aug[[k, i]];
                    for j in 0..2 * n {
                        aug[[k, j]] -= factor * aug[[i, j]];
                    }
                }
            }
        }

        let mut inv = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                inv[[i, j]] = aug[[i, j + n]];
            }
        }

        Ok(inv)
    }

    fn compute_log_likelihood(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
    ) -> SklResult<f64> {
        let n_samples = X.nrows();
        let n_components = weights.len();
        let mut log_likelihood = 0.0;

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut sample_likelihood = 0.0;

            for k in 0..n_components {
                if weights[k] > 1e-10 {
                    let log_pdf =
                        self.multivariate_normal_log_pdf(&sample, &means.row(k), &covariances[k])?;
                    sample_likelihood += weights[k] * log_pdf.exp();
                }
            }

            if sample_likelihood > 1e-300 {
                log_likelihood += sample_likelihood.ln();
            }
        }

        Ok(log_likelihood)
    }
}

/// Dirichlet Process Gaussian Mixture Model
///
/// A nonparametric Bayesian mixture model that uses the Dirichlet process as a prior
/// over the mixture weights. Uses stick-breaking construction and variational inference.
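///
/// # Examples
///
/// A minimal sketch (assumes this type is re-exported at the crate root alongside
/// `ChineseRestaurantProcess`):
///
/// ```text
/// use sklears_mixture::DirichletProcessGaussianMixture;
/// use sklears_core::traits::{Fit, Predict};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[0.0, 0.0], [0.5, 0.5], [10.0, 10.0], [10.5, 10.5]];
/// let fitted = DirichletProcessGaussianMixture::new()
///     .alpha(1.0)
///     .max_components(5)
///     .fit(&X.view(), &())
///     .unwrap();
/// let labels = fitted.predict(&X.view()).unwrap();
/// ```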
#[derive(Debug, Clone)]
pub struct DirichletProcessGaussianMixture<S = Untrained> {
    state: S,
    /// Concentration parameter of the Dirichlet process
    pub alpha: f64,
    /// Maximum number of components to consider
    pub max_components: usize,
    /// Type of covariance parameters
    pub covariance_type: CovarianceType,
    /// Convergence threshold
    pub tol: f64,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Regularization added to diagonal of covariance
    pub reg_covar: f64,
    /// Random state for reproducible results
    pub random_state: Option<u64>,
    /// Number of random initializations (currently only a single initialization is performed)
    pub n_init: usize,
}

impl DirichletProcessGaussianMixture<Untrained> {
    /// Create a new Dirichlet Process Gaussian Mixture Model
    pub fn new() -> Self {
        Self {
            state: Untrained,
            alpha: 1.0,
            max_components: 20,
            covariance_type: CovarianceType::Full,
            tol: 1e-3,
            max_iter: 100,
            reg_covar: 1e-6,
            random_state: None,
            n_init: 1,
        }
    }

    /// Set the concentration parameter
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the maximum number of components
    pub fn max_components(mut self, max_components: usize) -> Self {
        self.max_components = max_components;
        self
    }

    /// Set the covariance type
    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Set the convergence threshold
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the regularization parameter
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Set the number of initializations
    pub fn n_init(mut self, n_init: usize) -> Self {
        self.n_init = n_init;
        self
    }
}

impl Default for DirichletProcessGaussianMixture<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for DirichletProcessGaussianMixture<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for DirichletProcessGaussianMixture<Untrained> {
    type Fitted = DirichletProcessGaussianMixture<DirichletProcessGaussianMixtureTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "No samples provided".to_string(),
            ));
        }

        // Initialize stick-breaking weights
        let stick_weights = Array1::ones(self.max_components);
        let mut weights = Array1::zeros(self.max_components);

        // Deterministic stick-breaking start: split the stick evenly so every
        // component begins with weight 1 / max_components
        let mut remaining = 1.0;
        for k in 0..self.max_components {
            weights[k] = remaining / (self.max_components - k) as f64;
            remaining -= weights[k];
        }
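        // A random stick-breaking draw would instead sample v_k ~ Beta(1, alpha) and
        // set w_k = v_k * prod_{j<k}(1 - v_j); the even split above is a neutral
        // starting point that the variational updates move away from.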

        // Simple K-means++ initialization for means
        let mut rng = match self.random_state {
            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
            None => scirs2_core::random::rngs::StdRng::seed_from_u64(42),
        };

        let mut means = Array2::zeros((self.max_components, n_features));
        means.row_mut(0).assign(&X.row(rng.gen_range(0..n_samples)));

        for k in 1..self.max_components {
            let mut distances = Array1::zeros(n_samples);
            for i in 0..n_samples {
                let mut min_dist = f64::INFINITY;
                for j in 0..k {
                    let dist = (&X.row(i) - &means.row(j)).mapv(|x| x * x).sum().sqrt();
                    if dist < min_dist {
                        min_dist = dist;
                    }
                }
                distances[i] = min_dist * min_dist;
            }

            let total_dist: f64 = distances.sum();
            let target = rng.gen::<f64>() * total_dist;
            let mut cumulative = 0.0;

            for i in 0..n_samples {
                cumulative += distances[i];
                if cumulative >= target {
                    means.row_mut(k).assign(&X.row(i));
                    break;
                }
            }
        }

        // Initialize covariances
        let sample_cov = self.compute_sample_covariance(&X)?;
        let mut covariances = Vec::new();
        for _ in 0..self.max_components {
            covariances.push(sample_cov.clone());
        }

        let mut prev_lower_bound = f64::NEG_INFINITY;
        let mut converged = false;
        let mut n_iter = 0;

        // Variational EM iterations
        for iteration in 0..self.max_iter {
            n_iter = iteration + 1;

            // E-step: compute responsibilities
            let responsibilities =
                self.compute_responsibilities(&X, &weights, &means, &covariances)?;

            // M-step: update parameters
            weights = self.update_weights(&responsibilities)?;
            means = self.update_means(&X, &responsibilities)?;
            covariances = self.update_covariances(&X, &responsibilities, &means)?;

            // Compute lower bound
            let lower_bound =
                self.compute_lower_bound(&X, &responsibilities, &weights, &means, &covariances)?;

            // Record the bound before checking convergence so the stored value is
            // the most recent one rather than the previous iteration's
            let improvement = (lower_bound - prev_lower_bound).abs();
            prev_lower_bound = lower_bound;
            if improvement < self.tol {
                converged = true;
                break;
            }
        }

        // Determine effective number of components
        let mut n_components = 0;
        for k in 0..self.max_components {
            if weights[k] > 1e-3 {
                n_components = k + 1;
            }
        }

        Ok(DirichletProcessGaussianMixture {
            state: DirichletProcessGaussianMixtureTrained {
                weights: weights.slice(s![..n_components]).to_owned(),
                means: means.slice(s![..n_components, ..]).to_owned(),
                covariances: covariances.into_iter().take(n_components).collect(),
                weight_concentration: stick_weights.slice(s![..n_components]).to_owned(),
                lower_bound: prev_lower_bound,
                n_iter,
                converged,
                n_components,
                n_features,
                covariance_type: self.covariance_type.clone(),
                alpha: self.alpha,
                reg_covar: self.reg_covar,
            },
            alpha: self.alpha,
            max_components: self.max_components,
            covariance_type: self.covariance_type,
            tol: self.tol,
            max_iter: self.max_iter,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            n_init: self.n_init,
        })
    }
}

impl DirichletProcessGaussianMixture<Untrained> {
    fn compute_sample_covariance(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();
        let mean = X.mean_axis(Axis(0)).unwrap();
        let mut cov = Array2::zeros((n_features, n_features));

        for i in 0..n_samples {
            let diff = &X.row(i) - &mean;
            for j in 0..n_features {
                for k in 0..n_features {
                    cov[[j, k]] += diff[j] * diff[k];
                }
            }
        }
        // Unbiased estimate; guard the single-sample case against division by zero
        if n_samples > 1 {
            cov /= (n_samples - 1) as f64;
        }

        for i in 0..n_features {
            cov[[i, i]] += self.reg_covar;
        }

        Ok(cov)
    }

    fn compute_responsibilities(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
    ) -> SklResult<Array2<f64>> {
        let (n_samples, _) = X.dim();
        let n_components = weights.len();
        let mut responsibilities = Array2::zeros((n_samples, n_components));

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_probs = Array1::zeros(n_components);

            for k in 0..n_components {
                let log_weight = weights[k].ln();
                let log_pdf =
                    self.multivariate_normal_log_pdf(&sample, &means.row(k), &covariances[k])?;
                log_probs[k] = log_weight + log_pdf;
            }

            let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let log_sum_exp = max_log_prob
                + log_probs
                    .iter()
                    .map(|&lp| (lp - max_log_prob).exp())
                    .sum::<f64>()
                    .ln();

            for k in 0..n_components {
                responsibilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
            }
        }

        Ok(responsibilities)
    }

    fn update_weights(&self, responsibilities: &Array2<f64>) -> SklResult<Array1<f64>> {
        let n_components = responsibilities.dim().1;
        let mut weights = Array1::zeros(n_components);

        for k in 0..n_components {
            weights[k] = responsibilities.column(k).sum() / responsibilities.dim().0 as f64;
        }

        Ok(weights)
    }

    fn update_means(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
    ) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();
        let n_components = responsibilities.dim().1;
        let mut means = Array2::zeros((n_components, n_features));

        for k in 0..n_components {
            let weight_sum = responsibilities.column(k).sum();
            if weight_sum > 1e-10 {
                for j in 0..n_features {
                    let mut weighted_sum = 0.0;
                    for i in 0..n_samples {
                        weighted_sum += responsibilities[[i, k]] * X[[i, j]];
                    }
                    means[[k, j]] = weighted_sum / weight_sum;
                }
            }
        }

        Ok(means)
    }

    fn update_covariances(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
        means: &Array2<f64>,
    ) -> SklResult<Vec<Array2<f64>>> {
        let (n_samples, n_features) = X.dim();
        let n_components = responsibilities.dim().1;
        let mut covariances = Vec::new();

        for k in 0..n_components {
            let weight_sum = responsibilities.column(k).sum();
            let mut cov = Array2::zeros((n_features, n_features));

            if weight_sum > 1e-10 {
                for i in 0..n_samples {
                    let diff = &X.row(i) - &means.row(k);
                    let weight = responsibilities[[i, k]];
                    for j in 0..n_features {
                        for l in 0..n_features {
                            cov[[j, l]] += weight * diff[j] * diff[l];
                        }
                    }
                }
                cov /= weight_sum;
            } else {
                for j in 0..n_features {
                    cov[[j, j]] = 1.0;
                }
            }

            for j in 0..n_features {
                cov[[j, j]] += self.reg_covar;
            }

            covariances.push(cov);
        }

        Ok(covariances)
    }

    fn compute_lower_bound(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
    ) -> SklResult<f64> {
        let (n_samples, _) = X.dim();
        let n_components = weights.len();
        let mut lower_bound = 0.0;
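
        // Evidence lower bound under the current responsibilities r_ik:
        //   Σ_{i,k} r_ik (ln w_k + ln N(x_i; μ_k, Σ_k)) − Σ_{i,k} r_ik ln r_ik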

        // Data likelihood term
        for i in 0..n_samples {
            let sample = X.row(i);
            for k in 0..n_components {
                if responsibilities[[i, k]] > 1e-10 {
                    let log_pdf =
                        self.multivariate_normal_log_pdf(&sample, &means.row(k), &covariances[k])?;
                    lower_bound += responsibilities[[i, k]] * (weights[k].ln() + log_pdf);
                }
            }
        }

        // Entropy term
        for i in 0..n_samples {
            for k in 0..n_components {
                if responsibilities[[i, k]] > 1e-10 {
                    lower_bound -= responsibilities[[i, k]] * responsibilities[[i, k]].ln();
                }
            }
        }

        Ok(lower_bound)
    }

    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff = x - mean;

        let det = self.matrix_determinant(cov)?;
        if det <= 0.0 {
            return Ok(f64::NEG_INFINITY);
        }

        let inv_cov = self.matrix_inverse(cov)?;
        let mut quad_form = 0.0;
        for i in 0..diff.len() {
            for j in 0..diff.len() {
                quad_form += diff[i] * inv_cov[[i, j]] * diff[j];
            }
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + det.ln() + quad_form))
    }

    fn matrix_determinant(&self, A: &Array2<f64>) -> SklResult<f64> {
        let n = A.dim().0;
        if n == 1 {
            return Ok(A[[0, 0]]);
        }
        if n == 2 {
            return Ok(A[[0, 0]] * A[[1, 1]] - A[[0, 1]] * A[[1, 0]]);
        }

        let mut det = 1.0;
        let mut A_copy = A.clone();

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if A_copy[[k, i]].abs() > A_copy[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..n {
                    let temp = A_copy[[i, j]];
                    A_copy[[i, j]] = A_copy[[max_row, j]];
                    A_copy[[max_row, j]] = temp;
                }
                det *= -1.0;
            }

            if A_copy[[i, i]].abs() < 1e-12 {
                return Ok(0.0);
            }

            det *= A_copy[[i, i]];

            for k in i + 1..n {
                let factor = A_copy[[k, i]] / A_copy[[i, i]];
                for j in i..n {
                    A_copy[[k, j]] -= factor * A_copy[[i, j]];
                }
            }
        }

        Ok(det)
    }

    fn matrix_inverse(&self, A: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n = A.dim().0;
        let mut aug = Array2::zeros((n, 2 * n));

        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = A[[i, j]];
                aug[[i, j + n]] = if i == j { 1.0 } else { 0.0 };
            }
        }

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..2 * n {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[i, i]].abs() < 1e-12 {
                return Err(SklearsError::NumericalError(
                    "Matrix is singular".to_string(),
                ));
            }

            let pivot = aug[[i, i]];
            for j in 0..2 * n {
                aug[[i, j]] /= pivot;
            }

            for k in 0..n {
                if k != i {
                    let factor = aug[[k, i]];
                    for j in 0..2 * n {
                        aug[[k, j]] -= factor * aug[[i, j]];
                    }
                }
            }
        }

        let mut inv = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                inv[[i, j]] = aug[[i, j + n]];
            }
        }

        Ok(inv)
    }
}

/// Trained Dirichlet Process Gaussian Mixture Model
#[derive(Debug, Clone)]
pub struct DirichletProcessGaussianMixtureTrained {
    /// Effective mixture weights
    pub weights: Array1<f64>,
    /// Component means
    pub means: Array2<f64>,
    /// Component covariances
    pub covariances: Vec<Array2<f64>>,
    /// Variational parameters for stick-breaking weights
    pub weight_concentration: Array1<f64>,
    /// Lower bound on log-likelihood
    pub lower_bound: f64,
    /// Number of iterations performed
    pub n_iter: usize,
    /// Whether the algorithm converged
    pub converged: bool,
    /// Number of effective components
    pub n_components: usize,
    /// Number of features
    pub n_features: usize,
    /// Covariance type
    pub covariance_type: CovarianceType,
    /// Concentration parameter
    pub alpha: f64,
    /// Regularization parameter
    pub reg_covar: f64,
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for DirichletProcessGaussianMixture<DirichletProcessGaussianMixtureTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut max_log_prob = f64::NEG_INFINITY;
            let mut best_component = 0;

            for k in 0..self.state.n_components {
                let log_weight = self.state.weights[k].ln();
                let log_pdf = self.multivariate_normal_log_pdf(
                    &sample,
                    &self.state.means.row(k),
                    &self.state.covariances[k],
                )?;
                let log_prob = log_weight + log_pdf;

                if log_prob > max_log_prob {
                    max_log_prob = log_prob;
                    best_component = k;
                }
            }

            predictions[i] = best_component as i32;
        }

        Ok(predictions)
    }
}

impl DirichletProcessGaussianMixture<DirichletProcessGaussianMixtureTrained> {
    /// Predict class probabilities
    #[allow(non_snake_case)]
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut probabilities = Array2::zeros((n_samples, self.state.n_components));

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_probs = Array1::zeros(self.state.n_components);

            for k in 0..self.state.n_components {
                let log_weight = self.state.weights[k].ln();
                let log_pdf = self.multivariate_normal_log_pdf(
                    &sample,
                    &self.state.means.row(k),
                    &self.state.covariances[k],
                )?;
                log_probs[k] = log_weight + log_pdf;
            }

            let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let log_sum_exp = max_log_prob
                + log_probs
                    .iter()
                    .map(|&lp| (lp - max_log_prob).exp())
                    .sum::<f64>()
                    .ln();

            for k in 0..self.state.n_components {
                probabilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
            }
        }

        Ok(probabilities)
    }

    /// Score using the variational lower bound (returns the bound recorded during
    /// training; the input samples are currently unused)
    pub fn score(&self, _X: &ArrayView2<'_, Float>) -> SklResult<f64> {
        Ok(self.state.lower_bound)
    }

    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff = x - mean;

        let det = self.matrix_determinant(cov)?;
        if det <= 0.0 {
            return Ok(f64::NEG_INFINITY);
        }

        let inv_cov = self.matrix_inverse(cov)?;
        let mut quad_form = 0.0;
        for i in 0..diff.len() {
            for j in 0..diff.len() {
                quad_form += diff[i] * inv_cov[[i, j]] * diff[j];
            }
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + det.ln() + quad_form))
    }

    fn matrix_determinant(&self, A: &Array2<f64>) -> SklResult<f64> {
        let n = A.dim().0;
        if n == 1 {
            return Ok(A[[0, 0]]);
        }
        if n == 2 {
            return Ok(A[[0, 0]] * A[[1, 1]] - A[[0, 1]] * A[[1, 0]]);
        }

        let mut det = 1.0;
        let mut A_copy = A.clone();

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if A_copy[[k, i]].abs() > A_copy[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..n {
                    let temp = A_copy[[i, j]];
                    A_copy[[i, j]] = A_copy[[max_row, j]];
                    A_copy[[max_row, j]] = temp;
                }
                det *= -1.0;
            }

            if A_copy[[i, i]].abs() < 1e-12 {
                return Ok(0.0);
            }

            det *= A_copy[[i, i]];

            for k in i + 1..n {
                let factor = A_copy[[k, i]] / A_copy[[i, i]];
                for j in i..n {
                    A_copy[[k, j]] -= factor * A_copy[[i, j]];
                }
            }
        }

        Ok(det)
    }

    fn matrix_inverse(&self, A: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n = A.dim().0;
        let mut aug = Array2::zeros((n, 2 * n));

        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = A[[i, j]];
                aug[[i, j + n]] = if i == j { 1.0 } else { 0.0 };
            }
        }

        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..2 * n {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[i, i]].abs() < 1e-12 {
                return Err(SklearsError::NumericalError(
                    "Matrix is singular".to_string(),
                ));
            }

            let pivot = aug[[i, i]];
            for j in 0..2 * n {
                aug[[i, j]] /= pivot;
            }

            for k in 0..n {
                if k != i {
                    let factor = aug[[k, i]];
                    for j in 0..2 * n {
                        aug[[k, j]] -= factor * aug[[i, j]];
                    }
                }
            }
        }

        let mut inv = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                inv[[i, j]] = aug[[i, j + n]];
            }
        }

        Ok(inv)
    }
}

// Note: The Pitman-Yor Process and Hierarchical Dirichlet Process implementations
// would follow similar patterns but are omitted for brevity. They would include:
//
// 1. PitmanYorProcess<S> and PitmanYorProcessTrained structs
// 2. HierarchicalDirichletProcess<S> and HierarchicalDirichletProcessTrained structs
// 3. Similar trait implementations (Estimator, Fit, Predict)
// 4. Specialized stick-breaking constructions for each model
// 5. Appropriate inference algorithms (variational inference, Gibbs sampling)
//
// These can be added as needed based on specific requirements.
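//
// For reference, the Pitman-Yor seating rule generalizes the CRP with a discount
// parameter d in [0, 1): an occupied table k is chosen with probability
// (n_k − d) / (n + α) and a new table with (α + d·t) / (n + α), where t is the
// number of occupied tables; d = 0 recovers the CRP implemented above.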