use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::f64::consts::PI;

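/// Computes `ln(exp(a) + exp(b))` stably by factoring out the larger input.
/// If both inputs are `-inf` (an empty accumulator), the result stays `-inf`
/// rather than producing `NaN`.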
fn log_sum_exp(a: f64, b: f64) -> f64 {
    let max_val = a.max(b);
    if max_val.is_finite() {
        max_val + ((a - max_val).exp() + (b - max_val).exp()).ln()
    } else {
        max_val
    }
}

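/// Variational Bayesian estimation of a Gaussian mixture model.
///
/// Configure the model through the builder-style setters, then call `fit` to
/// obtain a trained `BayesianGaussianMixture<BayesianGaussianMixtureTrained>`.
/// The prior parameters are stored on the model but are not yet folded into
/// the parameter updates.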
#[derive(Debug, Clone)]
pub struct BayesianGaussianMixture<S = Untrained> {
    pub(crate) state: S,
    /// Maximum number of mixture components.
    n_components: usize,
    /// Requested covariance parameterization (only diagonal estimation is
    /// currently implemented, whatever this is set to).
    covariance_type: String,
    /// Convergence threshold on the change of the lower bound.
    tol: f64,
    /// Non-negative regularization added to the covariance diagonals.
    reg_covar: f64,
    /// Maximum number of EM iterations.
    max_iter: usize,
    random_state: Option<u64>,
    warm_start: bool,
    weight_concentration_prior_type: String,
    weight_concentration_prior: Option<f64>,
    mean_precision_prior: Option<f64>,
    mean_prior: Option<Array1<f64>>,
    degrees_of_freedom_prior: Option<f64>,
    covariance_prior: Option<f64>,
}

impl BayesianGaussianMixture<Untrained> {
    /// Creates an untrained model with one component, a "full" covariance
    /// type, `tol = 1e-3`, `reg_covar = 1e-6`, `max_iter = 100`, and a
    /// Dirichlet-process weight concentration prior.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_components: 1,
            covariance_type: "full".to_string(),
            tol: 1e-3,
            reg_covar: 1e-6,
            max_iter: 100,
            random_state: None,
            warm_start: false,
            weight_concentration_prior_type: "dirichlet_process".to_string(),
            weight_concentration_prior: None,
            mean_precision_prior: None,
            mean_prior: None,
            degrees_of_freedom_prior: None,
            covariance_prior: None,
        }
    }

    /// Sets the maximum number of mixture components.
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Sets the requested covariance parameterization.
    pub fn covariance_type(mut self, covariance_type: String) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Sets the convergence tolerance on the lower-bound change.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Sets the regularization added to the covariance diagonals.
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Sets the maximum number of EM iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the random seed (stored; initialization is currently deterministic).
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Enables or disables warm-starting from a previous fit.
    pub fn warm_start(mut self, warm_start: bool) -> Self {
        self.warm_start = warm_start;
        self
    }

    /// Sets the weight concentration prior type.
    pub fn weight_concentration_prior_type(mut self, prior_type: String) -> Self {
        self.weight_concentration_prior_type = prior_type;
        self
    }

    /// Sets the weight concentration prior.
    pub fn weight_concentration_prior(mut self, prior: f64) -> Self {
        self.weight_concentration_prior = Some(prior);
        self
    }

    /// Sets the mean precision prior.
    pub fn mean_precision_prior(mut self, prior: f64) -> Self {
        self.mean_precision_prior = Some(prior);
        self
    }
}

impl Default for BayesianGaussianMixture<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for BayesianGaussianMixture<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

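// Fitting runs a simplified EM-style loop: the E-step computes per-sample
// responsibilities, the M-step re-estimates weights, means, and diagonal
// covariances, and iteration stops once the lower bound changes by less
// than `tol`. The configured priors are stored on the model but are not yet
// folded into these updates.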
impl Fit<ArrayView2<'_, Float>, ()> for BayesianGaussianMixture<Untrained> {
    type Fitted = BayesianGaussianMixture<BayesianGaussianMixtureTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, _n_features) = X.dim();

        if n_samples < self.n_components {
            return Err(SklearsError::InvalidInput(
                "Number of samples must be at least the number of components".to_string(),
            ));
        }

        // Start from uniform weights, spread-out means, and identity covariances.
        let mut weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);
        let mut means = self.initialize_means(&X)?;
        let mut covariances = self.initialize_covariances(&X, &means)?;

        let mut responsibilities = Array2::zeros((n_samples, self.n_components));
        let mut lower_bound = f64::NEG_INFINITY;
        let mut converged = false;
        let mut n_iter = 0;

        for iteration in 0..self.max_iter {
            // E-step: recompute responsibilities under the current parameters.
            self.update_responsibilities(
                &X,
                &weights,
                &means,
                &covariances,
                &mut responsibilities,
            )?;

            // M-step: re-estimate weights, means, and covariances.
            let (new_weights, new_means, new_covariances) =
                self.update_parameters(&X, &responsibilities)?;

            let new_lower_bound = self.compute_lower_bound(
                &X,
                &responsibilities,
                &new_weights,
                &new_means,
                &new_covariances,
            );

            if iteration > 0 && (new_lower_bound - lower_bound).abs() < self.tol {
                converged = true;
            }

            weights = new_weights;
            means = new_means;
            covariances = new_covariances;
            lower_bound = new_lower_bound;
            n_iter = iteration + 1;

            if converged {
                break;
            }
        }

        // Components whose weight collapses well below uniform are not counted
        // toward the effective number of components.
        let weight_threshold = 1.0 / (self.n_components as f64 * 100.0);
        let n_components_effective = weights.iter().filter(|&&w| w > weight_threshold).count();

        Ok(BayesianGaussianMixture {
            state: BayesianGaussianMixtureTrained {
                weights,
                means,
                covariances,
                n_components_effective,
                lower_bound,
                converged,
                n_iter,
            },
            n_components: self.n_components,
            covariance_type: self.covariance_type,
            tol: self.tol,
            reg_covar: self.reg_covar,
            max_iter: self.max_iter,
            random_state: self.random_state,
            warm_start: self.warm_start,
            weight_concentration_prior_type: self.weight_concentration_prior_type,
            weight_concentration_prior: self.weight_concentration_prior,
            mean_precision_prior: self.mean_precision_prior,
            mean_prior: self.mean_prior,
            degrees_of_freedom_prior: self.degrees_of_freedom_prior,
            covariance_prior: self.covariance_prior,
        })
    }
}

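// Internal helpers for the EM-style loop; covariance handling throughout is
// diagonal.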
impl BayesianGaussianMixture<Untrained> {
    /// Initializes component means by taking evenly spaced rows of `X`.
    fn initialize_means(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        let (_, n_features) = X.dim();
        let mut means = Array2::zeros((self.n_components, n_features));

        let step = X.nrows() / self.n_components;
        for (i, mut mean) in means.axis_iter_mut(Axis(0)).enumerate() {
            let sample_idx = (i * step).min(X.nrows() - 1);
            mean.assign(&X.row(sample_idx));
        }

        Ok(means)
    }

    /// Initializes every covariance to a regularized identity matrix.
    fn initialize_covariances(
        &self,
        X: &Array2<f64>,
        _means: &Array2<f64>,
    ) -> SklResult<Vec<Array2<f64>>> {
        let (_, n_features) = X.dim();

        let mut covariances = Vec::new();
        for _ in 0..self.n_components {
            let mut cov = Array2::eye(n_features);
            for i in 0..n_features {
                cov[[i, i]] += self.reg_covar;
            }
            covariances.push(cov);
        }

        Ok(covariances)
    }

    /// E-step: fills `responsibilities[[i, k]]` with the posterior probability
    /// that sample `i` was generated by component `k`, computed in log space
    /// for numerical stability.
    fn update_responsibilities(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
        responsibilities: &mut Array2<f64>,
    ) -> SklResult<()> {
        let (n_samples, _) = X.dim();

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_prob_sum = f64::NEG_INFINITY;
            let mut log_probs = Vec::new();

            for k in 0..self.n_components {
                let mean = means.row(k);
                let cov = &covariances[k];
                let log_weight = weights[k].ln();
                let log_likelihood = self.multivariate_normal_log_pdf(&sample, &mean, cov)?;
                let log_prob = log_weight + log_likelihood;
                log_probs.push(log_prob);
                log_prob_sum = log_sum_exp(log_prob_sum, log_prob);
            }

            // Normalize in log space, then exponentiate.
            for k in 0..self.n_components {
                responsibilities[[i, k]] = (log_probs[k] - log_prob_sum).exp();
            }
        }

        Ok(())
    }

    /// M-step: re-estimates weights, means, and diagonal covariances from the
    /// current responsibilities. Components with negligible responsibility
    /// mass keep a default unit covariance.
    fn update_parameters(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
    ) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
        let (n_samples, n_features) = X.dim();

        // Effective number of samples assigned to each component.
        let n_k: Array1<f64> = responsibilities.sum_axis(Axis(0));
        let weights = &n_k / n_samples as f64;

        // Responsibility-weighted means.
        let mut means = Array2::zeros((self.n_components, n_features));
        for k in 0..self.n_components {
            if n_k[k] > 1e-10 {
                for i in 0..n_samples {
                    for j in 0..n_features {
                        means[[k, j]] += responsibilities[[i, k]] * X[[i, j]];
                    }
                }
                for j in 0..n_features {
                    means[[k, j]] /= n_k[k];
                }
            }
        }

        // Responsibility-weighted per-feature variances on the diagonal.
        let mut covariances = Vec::new();
        for k in 0..self.n_components {
            let mut cov = Array2::eye(n_features);

            if n_k[k] > 1e-10 {
                let mean_k = means.row(k);

                for d in 0..n_features {
                    let mut var = 0.0;
                    for i in 0..n_samples {
                        let diff = X[[i, d]] - mean_k[d];
                        var += responsibilities[[i, k]] * diff * diff;
                    }
                    var /= n_k[k];
                    cov[[d, d]] = var + self.reg_covar;
                }
            } else {
                for d in 0..n_features {
                    cov[[d, d]] = 1.0 + self.reg_covar;
                }
            }

            covariances.push(cov);
        }

        Ok((weights, means, covariances))
    }

    /// Log-density of a multivariate normal, assuming a diagonal covariance:
    /// only the diagonal entries of `cov` enter the determinant and the
    /// quadratic form.
    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff: Array1<f64> = x - mean;

        let mut log_det = 0.0;
        for i in 0..cov.nrows() {
            log_det += cov[[i, i]].ln();
        }

        let mut quad_form = 0.0;
        for i in 0..diff.len() {
            quad_form += diff[i] * diff[i] / cov[[i, i]];
        }

        let log_pdf = -0.5 * (d * (2.0 * PI).ln() + log_det + quad_form);
        Ok(log_pdf)
    }

    /// Mean log-likelihood of the data under the given parameters, used as the
    /// convergence criterion. A full variational bound would add prior KL terms.
    fn compute_lower_bound(
        &self,
        X: &Array2<f64>,
        _responsibilities: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
    ) -> f64 {
        let mut total = 0.0;
        for sample in X.axis_iter(Axis(0)) {
            let mut log_prob = f64::NEG_INFINITY;
            for k in 0..self.n_components {
                let log_likelihood = self
                    .multivariate_normal_log_pdf(&sample, &means.row(k), &covariances[k])
                    .unwrap_or(f64::NEG_INFINITY);
                log_prob = log_sum_exp(log_prob, weights[k].ln() + log_likelihood);
            }
            total += log_prob;
        }
        total / X.nrows() as f64
    }
}

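// Hard assignment: each sample is labeled with the component that maximizes
// `ln(weight_k) + ln N(x | mean_k, cov_k)`.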
impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for BayesianGaussianMixture<BayesianGaussianMixtureTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut max_log_prob = f64::NEG_INFINITY;
            let mut best_component = 0;

            for k in 0..self.n_components {
                let mean = self.state.means.row(k);
                let cov = &self.state.covariances[k];
                let log_weight = self.state.weights[k].ln();

                if let Ok(log_likelihood) = self.multivariate_normal_log_pdf(&sample, &mean, cov) {
                    let log_prob = log_weight + log_likelihood;
                    if log_prob > max_log_prob {
                        max_log_prob = log_prob;
                        best_component = k;
                    }
                }
            }

            predictions[i] = best_component as i32;
        }

        Ok(predictions)
    }
}

impl BayesianGaussianMixture<BayesianGaussianMixtureTrained> {
    /// Mixing weights of each component.
    pub fn weights(&self) -> &Array1<f64> {
        &self.state.weights
    }

    /// Component means, one row per component.
    pub fn means(&self) -> &Array2<f64> {
        &self.state.means
    }

    /// Component covariance matrices (diagonal).
    pub fn covariances(&self) -> &[Array2<f64>] {
        &self.state.covariances
    }

    /// Number of components whose weight stayed above the pruning threshold.
    pub fn n_components_effective(&self) -> usize {
        self.state.n_components_effective
    }

    /// Final value of the lower bound at the end of fitting.
    pub fn lower_bound(&self) -> f64 {
        self.state.lower_bound
    }

    /// Whether the lower bound converged within `max_iter` iterations.
    pub fn converged(&self) -> bool {
        self.state.converged
    }

    /// Number of EM iterations actually performed.
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }

    /// Posterior probability of each component for every sample, computed the
    /// same way as the responsibilities during fitting.
    #[allow(non_snake_case)]
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut probabilities = Array2::zeros((n_samples, self.n_components));

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_prob_sum = f64::NEG_INFINITY;
            let mut log_probs = Vec::new();

            for k in 0..self.n_components {
                let mean = self.state.means.row(k);
                let cov = &self.state.covariances[k];
                let log_weight = self.state.weights[k].ln();
                let log_likelihood = self.multivariate_normal_log_pdf(&sample, &mean, cov)?;
                let log_prob = log_weight + log_likelihood;
                log_probs.push(log_prob);
                log_prob_sum = log_sum_exp(log_prob_sum, log_prob);
            }

            for k in 0..self.n_components {
                probabilities[[i, k]] = (log_probs[k] - log_prob_sum).exp();
            }
        }

        Ok(probabilities)
    }

    /// Per-sample log-likelihood under the mixture.
    #[allow(non_snake_case)]
    pub fn score_samples(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut scores = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_prob_sum = f64::NEG_INFINITY;

            for k in 0..self.n_components {
                let mean = self.state.means.row(k);
                let cov = &self.state.covariances[k];
                let log_weight = self.state.weights[k].ln();
                let log_likelihood = self.multivariate_normal_log_pdf(&sample, &mean, cov)?;
                let log_prob = log_weight + log_likelihood;
                log_prob_sum = log_sum_exp(log_prob_sum, log_prob);
            }

            scores[i] = log_prob_sum;
        }

        Ok(scores)
    }

    /// Mean per-sample log-likelihood over `X`.
    pub fn score(&self, X: &ArrayView2<'_, Float>) -> SklResult<f64> {
        let scores = self.score_samples(X)?;
        Ok(scores.mean().unwrap_or(0.0))
    }

    /// Log-density of a multivariate normal with a diagonal covariance; see
    /// the identical helper on the untrained type.
    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff: Array1<f64> = x - mean;

        let mut log_det = 0.0;
        for i in 0..cov.nrows() {
            log_det += cov[[i, i]].ln();
        }

        let mut quad_form = 0.0;
        for i in 0..diff.len() {
            quad_form += diff[i] * diff[i] / cov[[i, i]];
        }

        let log_pdf = -0.5 * (d * (2.0 * PI).ln() + log_det + quad_form);
        Ok(log_pdf)
    }
}

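/// Trained state produced by fitting a [`BayesianGaussianMixture`].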
#[derive(Debug, Clone)]
pub struct BayesianGaussianMixtureTrained {
    /// Mixing weight of each component.
    pub weights: Array1<f64>,
    /// Component means, one row per component.
    pub means: Array2<f64>,
    /// Component covariance matrices (diagonal).
    pub covariances: Vec<Array2<f64>>,
    /// Number of components with non-negligible weight.
    pub n_components_effective: usize,
    /// Final lower bound reached during fitting.
    pub lower_bound: f64,
    /// Whether fitting converged before `max_iter`.
    pub converged: bool,
    /// Number of iterations performed.
    pub n_iter: usize,
}
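
// A minimal smoke test sketching intended usage (assuming `Float` is `f64`):
// two well-separated blobs should each be claimed by one component, and
// `predict_proba` rows should sum to one.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fits_two_separated_blobs() {
        // Six samples: three near the origin, three near (10, 10).
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                0.0, 0.0, 0.1, -0.1, -0.1, 0.1, 10.0, 10.0, 10.1, 9.9, 9.9, 10.1,
            ],
        )
        .unwrap();

        let model = BayesianGaussianMixture::new()
            .n_components(2)
            .max_iter(200)
            .tol(1e-6)
            .fit(&x.view(), &())
            .expect("fit should succeed on well-separated data");

        // Samples from the same blob share a label; the blobs differ.
        let labels = model.predict(&x.view()).expect("predict should succeed");
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[3], labels[4]);
        assert_ne!(labels[0], labels[3]);

        // Posterior probabilities are normalized per sample.
        let proba = model.predict_proba(&x.view()).expect("predict_proba");
        for row in proba.axis_iter(Axis(0)) {
            assert!((row.sum() - 1.0).abs() < 1e-9);
        }
    }
}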