sklears_semi_supervised/
bayesian_methods.rs

//! Bayesian methods for semi-supervised learning
//!
//! This module provides Bayesian approaches to semi-supervised learning,
//! including Gaussian process methods, variational inference, and
//! hierarchical Bayesian models for learning with both labeled and unlabeled data.

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::Random;
use sklears_core::error::{Result as SklResult, SklearsError};
use sklears_core::traits::{Estimator, Fit, Predict, PredictProba, Untrained};
use sklears_core::types::Float;
use std::f64::consts::PI;

/// Gaussian Process Semi-Supervised Learning
///
/// This method uses Gaussian processes to perform semi-supervised learning
/// by treating the labeled samples as observed function values and inferring
/// the function values at unlabeled points through GP inference.
///
/// # Parameters
///
/// * `kernel` - Kernel function ('rbf', 'linear', 'polynomial')
/// * `length_scale` - Length scale parameter for RBF kernel
/// * `noise_level` - Noise level (variance) for observations
/// * `alpha` - Regularization parameter
/// * `n_restarts_optimizer` - Number of restarts for hyperparameter optimization
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
///
/// ```rust,ignore
/// use scirs2_core::array;
/// use sklears_core::traits::{Fit, Predict};
/// use sklears_semi_supervised::GaussianProcessSemiSupervised;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let gp = GaussianProcessSemiSupervised::new()
///     .kernel("rbf".to_string())
///     .length_scale(1.0)
///     .noise_level(0.1);
/// let fitted = gp.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct GaussianProcessSemiSupervised<S = Untrained> {
    state: S,
    kernel: String,
    length_scale: f64,
    noise_level: f64,
    alpha: f64,
    n_restarts_optimizer: usize,
    random_state: Option<u64>,
}

impl GaussianProcessSemiSupervised<Untrained> {
    /// Create a new GaussianProcessSemiSupervised instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernel: "rbf".to_string(),
            length_scale: 1.0,
            noise_level: 0.1,
            alpha: 1e-10,
            n_restarts_optimizer: 0,
            random_state: None,
        }
    }

    /// Set the kernel function
    pub fn kernel(mut self, kernel: String) -> Self {
        self.kernel = kernel;
        self
    }

    /// Set the length scale parameter
    pub fn length_scale(mut self, length_scale: f64) -> Self {
        self.length_scale = length_scale;
        self
    }

    /// Set the noise level
    pub fn noise_level(mut self, noise_level: f64) -> Self {
        self.noise_level = noise_level;
        self
    }

    /// Set the alpha regularization parameter
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the number of optimizer restarts
    pub fn n_restarts_optimizer(mut self, n_restarts: usize) -> Self {
        self.n_restarts_optimizer = n_restarts;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for GaussianProcessSemiSupervised<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for GaussianProcessSemiSupervised<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for GaussianProcessSemiSupervised<Untrained> {
    type Fitted = GaussianProcessSemiSupervised<GaussianProcessTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        // Identify labeled and unlabeled samples
        let mut labeled_indices = Vec::new();
        let mut unlabeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label == -1 {
                unlabeled_indices.push(i);
            } else {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        let classes: Vec<i32> = classes.into_iter().collect();
        let n_classes = classes.len();

        // Convert class labels to one-hot regression targets for the GP (one-vs-rest)
        let mut regression_targets = Array2::<f64>::zeros((labeled_indices.len(), n_classes));
        for (i, &idx) in labeled_indices.iter().enumerate() {
            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                regression_targets[[i, class_idx]] = 1.0;
            }
        }

        // Extract labeled features
        let mut X_labeled = Array2::<f64>::zeros((labeled_indices.len(), X.ncols()));
        for (i, &idx) in labeled_indices.iter().enumerate() {
            X_labeled.row_mut(i).assign(&X.row(idx));
        }

        // Compute kernel matrix for labeled points
        let K_labeled = self.compute_kernel_matrix(&X_labeled, &X_labeled)?;

        // Add noise to the diagonal
        let mut K_noise = K_labeled.clone();
        for i in 0..K_noise.nrows() {
            K_noise[[i, i]] += self.noise_level + self.alpha;
        }

        // Solve for GP weights (simplified; a full implementation would use a
        // Cholesky decomposition)
        let GP_weights = self.solve_gp_system(&K_noise, &regression_targets)?;

        // Predict for all points (including unlabeled)
        let K_all = self.compute_kernel_matrix(&X, &X_labeled)?;
        let predictions_all = K_all.dot(&GP_weights);

        // Generate final labels for unlabeled samples
        let mut final_labels = y.clone();
        for &idx in &unlabeled_indices {
            let class_idx = predictions_all
                .row(idx)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            final_labels[idx] = classes[class_idx];
        }

        Ok(GaussianProcessSemiSupervised {
            state: GaussianProcessTrained {
                X_train: X,
                y_train: final_labels,
                classes: Array1::from(classes),
                X_labeled,
                GP_weights,
                predictions_all,
            },
            kernel: self.kernel,
            length_scale: self.length_scale,
            noise_level: self.noise_level,
            alpha: self.alpha,
            n_restarts_optimizer: self.n_restarts_optimizer,
            random_state: self.random_state,
        })
    }
}

impl GaussianProcessSemiSupervised<Untrained> {
    /// Compute kernel matrix between two sets of points
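    ///
    /// For reference, the kernels (read directly off the match arms below) are:
    ///
    /// ```text
    /// rbf:        k(x, x') = exp(-||x - x'||^2 / (2 * length_scale^2))
    /// linear:     k(x, x') = <x, x'>
    /// polynomial: k(x, x') = (1 + <x, x'> / length_scale)^2
    /// ```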
    fn compute_kernel_matrix(&self, X1: &Array2<f64>, X2: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n1 = X1.nrows();
        let n2 = X2.nrows();
        let mut K = Array2::<f64>::zeros((n1, n2));

        match self.kernel.as_str() {
            "rbf" => {
                for i in 0..n1 {
                    for j in 0..n2 {
                        let diff = &X1.row(i) - &X2.row(j);
                        let dist_sq = diff.mapv(|x| x * x).sum();
                        K[[i, j]] = (-dist_sq / (2.0 * self.length_scale.powi(2))).exp();
                    }
                }
            }
            "linear" => {
                for i in 0..n1 {
                    for j in 0..n2 {
                        K[[i, j]] = X1.row(i).dot(&X2.row(j));
                    }
                }
            }
            "polynomial" => {
                let degree = 2.0;
                for i in 0..n1 {
                    for j in 0..n2 {
                        let dot_product = X1.row(i).dot(&X2.row(j));
                        K[[i, j]] = (1.0 + dot_product / self.length_scale).powf(degree);
                    }
                }
            }
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown kernel: {}",
                    self.kernel
                )));
            }
        }

        Ok(K)
    }

    /// Solve the GP linear system `K w = targets` (simplified; a full
    /// implementation would use a Cholesky decomposition for numerical stability)
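    ///
    /// Each target column is solved by Jacobi iteration:
    ///
    /// ```text
    /// w_i_new = (b_i - sum_{j != i} K[i, j] * w_j) / K[i, i]
    /// ```
    ///
    /// Jacobi iteration is guaranteed to converge only for strictly diagonally
    /// dominant systems; the noise/alpha jitter added to the diagonal helps
    /// but does not guarantee dominance in general.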
    fn solve_gp_system(&self, K: &Array2<f64>, targets: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n = K.nrows();
        let n_targets = targets.ncols();

        let mut weights = Array2::<f64>::zeros((n, n_targets));

        for target_idx in 0..n_targets {
            // Simple iterative solution (Jacobi method)
            let mut x = Array1::<f64>::zeros(n);
            let target_col = targets.column(target_idx);

            for _iter in 0..100 {
                let mut x_new = Array1::<f64>::zeros(n);

                for i in 0..n {
                    let mut sum = 0.0;
                    for j in 0..n {
                        if i != j {
                            sum += K[[i, j]] * x[j];
                        }
                    }
                    x_new[i] = (target_col[i] - sum) / K[[i, i]];
                }

                // Check convergence
                let diff = (&x_new - &x).mapv(|x| x.abs()).sum();
                if diff < 1e-6 {
                    break;
                }
                x = x_new;
            }

            // Store the solution
            for i in 0..n {
                weights[[i, target_idx]] = x[i];
            }
        }

        Ok(weights)
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for GaussianProcessSemiSupervised<GaussianProcessTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        // Compute kernel between test points and training points
        let K_test = self
            .compute_kernel_matrix(&X, &self.state.X_labeled)
            .map_err(|e| SklearsError::PredictError(e.to_string()))?;

        // GP prediction
        let gp_predictions = K_test.dot(&self.state.GP_weights);

        for i in 0..n_test {
            // Find class with highest GP prediction
            let class_idx = gp_predictions
                .row(i)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            predictions[i] = self.state.classes[class_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for GaussianProcessSemiSupervised<GaussianProcessTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();

        // Compute kernel between test points and training points
        let K_test = self
            .compute_kernel_matrix(&X, &self.state.X_labeled)
            .map_err(|e| SklearsError::PredictError(e.to_string()))?;

        // GP prediction
        let gp_predictions = K_test.dot(&self.state.GP_weights);

        // Convert to probabilities using softmax
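        // softmax(z)_j = exp(z_j - max(z)) / sum_k exp(z_k - max(z));
        // subtracting the row maximum keeps the exponentials numerically stable.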
        let mut probabilities = Array2::<f64>::zeros((n_test, n_classes));
        for i in 0..n_test {
            let row = gp_predictions.row(i);

            // Softmax transformation
            let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            let mut exp_sum = 0.0;

            for j in 0..n_classes {
                let exp_val = (row[j] - max_val).exp();
                probabilities[[i, j]] = exp_val;
                exp_sum += exp_val;
            }

            // Normalize
            if exp_sum > 0.0 {
                for j in 0..n_classes {
                    probabilities[[i, j]] /= exp_sum;
                }
            } else {
                // Defensive fallback to a uniform distribution (after max
                // subtraction, exp_sum >= 1 in normal operation)
                for j in 0..n_classes {
                    probabilities[[i, j]] = 1.0 / n_classes as f64;
                }
            }
        }

        Ok(probabilities)
    }
}

impl GaussianProcessSemiSupervised<GaussianProcessTrained> {
    /// Compute kernel matrix between two sets of points (for prediction)
    fn compute_kernel_matrix(&self, X1: &Array2<f64>, X2: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n1 = X1.nrows();
        let n2 = X2.nrows();
        let mut K = Array2::<f64>::zeros((n1, n2));

        match self.kernel.as_str() {
            "rbf" => {
                for i in 0..n1 {
                    for j in 0..n2 {
                        let diff = &X1.row(i) - &X2.row(j);
                        let dist_sq = diff.mapv(|x| x * x).sum();
                        K[[i, j]] = (-dist_sq / (2.0 * self.length_scale.powi(2))).exp();
                    }
                }
            }
            "linear" => {
                for i in 0..n1 {
                    for j in 0..n2 {
                        K[[i, j]] = X1.row(i).dot(&X2.row(j));
                    }
                }
            }
            "polynomial" => {
                let degree = 2.0;
                for i in 0..n1 {
                    for j in 0..n2 {
                        let dot_product = X1.row(i).dot(&X2.row(j));
                        K[[i, j]] = (1.0 + dot_product / self.length_scale).powf(degree);
                    }
                }
            }
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown kernel: {}",
                    self.kernel
                )));
            }
        }

        Ok(K)
    }
}

/// Variational Bayesian Semi-Supervised Learning
///
/// This method uses variational inference to learn from both labeled and unlabeled data
/// by maximizing a variational lower bound on the log-likelihood.
///
/// # Parameters
///
/// * `n_components` - Number of mixture components
/// * `max_iter` - Maximum number of iterations
/// * `tol` - Convergence tolerance
/// * `reg_covar` - Regularization for covariance matrices
/// * `alpha_prior` - Prior concentration parameter
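///
/// # Examples
///
/// A minimal usage sketch, mirroring the unit tests at the bottom of this
/// module (the data here is illustrative only):
///
/// ```rust,ignore
/// use scirs2_core::array;
/// use sklears_core::traits::{Fit, Predict};
/// use sklears_semi_supervised::VariationalBayesianSemiSupervised;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let vb = VariationalBayesianSemiSupervised::new()
///     .n_components(2)
///     .max_iter(50)
///     .random_state(42);
/// let fitted = vb.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```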
#[derive(Debug, Clone)]
pub struct VariationalBayesianSemiSupervised<S = Untrained> {
    state: S,
    n_components: usize,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    alpha_prior: f64,
    random_state: Option<u64>,
}

impl VariationalBayesianSemiSupervised<Untrained> {
    /// Create a new VariationalBayesianSemiSupervised instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_components: 2,
            max_iter: 100,
            tol: 1e-4,
            reg_covar: 1e-6,
            alpha_prior: 1.0,
            random_state: None,
        }
    }

    /// Set the number of components
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the covariance regularization
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Set the alpha prior
    pub fn alpha_prior(mut self, alpha_prior: f64) -> Self {
        self.alpha_prior = alpha_prior;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for VariationalBayesianSemiSupervised<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for VariationalBayesianSemiSupervised<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>>
    for VariationalBayesianSemiSupervised<Untrained>
{
    type Fitted = VariationalBayesianSemiSupervised<VariationalBayesianTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();
        let (n_samples, n_features) = X.dim();

        // Identify labeled and unlabeled samples
        let mut labeled_indices = Vec::new();
        let mut unlabeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label == -1 {
                unlabeled_indices.push(i);
            } else {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        let classes: Vec<i32> = classes.into_iter().collect();

        // Initialize random number generator
        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            Random::seed(
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            )
        };

        // Initialize variational parameters
        let mut means = Array2::<f64>::zeros((self.n_components, n_features));
        let mut covariances = Vec::new();
        let mut mixing_weights = Array1::<f64>::ones(self.n_components) / self.n_components as f64;

        // Initialize means randomly
        for k in 0..self.n_components {
            for j in 0..n_features {
                means[[k, j]] = rng.random_range(-1.0..1.0);
            }
        }

        // Initialize covariances as identity matrices
        for _k in 0..self.n_components {
            let mut cov = Array2::<f64>::zeros((n_features, n_features));
            for i in 0..n_features {
                cov[[i, i]] = 1.0 + self.reg_covar;
            }
            covariances.push(cov);
        }

        // Variational EM algorithm
        let mut responsibilities = Array2::<f64>::zeros((n_samples, self.n_components));

        for _iter in 0..self.max_iter {
            let prev_means = means.clone();

            // E-step: Update responsibilities
            for i in 0..n_samples {
                let mut log_prob_norm = f64::NEG_INFINITY;

                // First pass: find the maximum log posterior across components
                // (for the log-sum-exp trick below)
                for k in 0..self.n_components {
                    let log_prob =
                        self.compute_log_probability(&X.row(i), &means.row(k), &covariances[k]);
                    let log_resp = log_prob + mixing_weights[k].ln();

                    if log_resp > log_prob_norm {
                        log_prob_norm = log_resp;
                    }
                }

                // Compute responsibilities using the log-sum-exp trick
                let mut exp_sum = 0.0;
                for k in 0..self.n_components {
                    let log_prob =
                        self.compute_log_probability(&X.row(i), &means.row(k), &covariances[k]);
                    let log_resp = log_prob + mixing_weights[k].ln() - log_prob_norm;
                    responsibilities[[i, k]] = log_resp.exp();
                    exp_sum += responsibilities[[i, k]];
                }

                // Normalize responsibilities
                if exp_sum > 0.0 {
                    for k in 0..self.n_components {
                        responsibilities[[i, k]] /= exp_sum;
                    }
                }
            }

            // M-step: Update means and mixing weights (covariances stay at
            // their initialization in this simplified variant)
            for k in 0..self.n_components {
                let n_k: f64 = responsibilities.column(k).sum();

                if n_k > 1e-10 {
                    // Update means
                    let mut new_mean = Array1::<f64>::zeros(n_features);
                    for i in 0..n_samples {
                        for j in 0..n_features {
                            new_mean[j] += responsibilities[[i, k]] * X[[i, j]];
                        }
                    }
                    new_mean /= n_k;
                    means.row_mut(k).assign(&new_mean);

                    // Update mixing weights
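                    // MAP estimate under a symmetric Dirichlet(alpha_prior)
                    // prior: pi_k = (N_k + alpha - 1) / (N + K * (alpha - 1))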
                    mixing_weights[k] = (n_k + self.alpha_prior - 1.0)
                        / (n_samples as f64 + self.n_components as f64 * self.alpha_prior
                            - self.n_components as f64);
                }
            }

            // Check convergence
            let diff = (&means - &prev_means).mapv(|x| x.abs()).sum();
            if diff < self.tol {
                break;
            }
        }

        // Predict labels for unlabeled samples
        let mut final_labels = y.clone();
        for &idx in &unlabeled_indices {
            // Assign to component with highest responsibility
            let best_component = responsibilities
                .row(idx)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;

            // Map component to class (simplified)
            let predicted_class = classes[best_component % classes.len()];
            final_labels[idx] = predicted_class;
        }

        Ok(VariationalBayesianSemiSupervised {
            state: VariationalBayesianTrained {
                X_train: X,
                y_train: final_labels,
                classes: Array1::from(classes),
                means,
                covariances,
                mixing_weights,
                responsibilities,
            },
            n_components: self.n_components,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            alpha_prior: self.alpha_prior,
            random_state: self.random_state,
        })
    }
}

impl VariationalBayesianSemiSupervised<Untrained> {
    /// Compute log probability of a sample under a Gaussian distribution
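    ///
    /// Only the diagonal of `covariance` is used, i.e. this evaluates
    ///
    /// ```text
    /// log N(x | mu, diag(var)) =
    ///     -d/2 * ln(2*pi) - 1/2 * sum_i (ln(var_i) + (x_i - mu_i)^2 / var_i)
    /// ```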
    fn compute_log_probability(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        covariance: &Array2<f64>,
    ) -> f64 {
        let d = x.len() as f64;
        let diff = x.to_owned() - mean.to_owned();

        // Simplified: assume diagonal covariance for computational efficiency
        let mut log_prob = -0.5 * d * (2.0 * PI).ln();

        for i in 0..diff.len() {
            let var = covariance[[i, i]];
            log_prob -= 0.5 * (var.ln() + diff[i] * diff[i] / var);
        }

        log_prob
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for VariationalBayesianSemiSupervised<VariationalBayesianTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Find most likely component
            let mut max_log_prob = f64::NEG_INFINITY;
            let mut best_component = 0;

            for k in 0..self.n_components {
                let log_prob = self.compute_log_probability(
                    &X.row(i),
                    &self.state.means.row(k),
                    &self.state.covariances[k],
                ) + self.state.mixing_weights[k].ln();

                if log_prob > max_log_prob {
                    max_log_prob = log_prob;
                    best_component = k;
                }
            }

            // Map component to class
            let predicted_class = self.state.classes[best_component % self.state.classes.len()];
            predictions[i] = predicted_class;
        }

        Ok(predictions)
    }
}

impl VariationalBayesianSemiSupervised<VariationalBayesianTrained> {
    /// Compute log probability of a sample under a Gaussian distribution
    fn compute_log_probability(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        covariance: &Array2<f64>,
    ) -> f64 {
        let d = x.len() as f64;
        let diff = x.to_owned() - mean.to_owned();

        // Simplified: assume diagonal covariance
        let mut log_prob = -0.5 * d * (2.0 * PI).ln();

        for i in 0..diff.len() {
            let var = covariance[[i, i]];
            log_prob -= 0.5 * (var.ln() + diff[i] * diff[i] / var);
        }

        log_prob
    }
}

/// Trained state for GaussianProcessSemiSupervised
#[derive(Debug, Clone)]
pub struct GaussianProcessTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels (with unlabeled samples filled in after fitting)
    pub y_train: Array1<i32>,
    /// Distinct class labels
    pub classes: Array1<i32>,
    /// Features of the labeled subset
    pub X_labeled: Array2<f64>,
    /// GP weight matrix (one column per class)
    pub GP_weights: Array2<f64>,
    /// GP outputs for all training points
    pub predictions_all: Array2<f64>,
}

/// Trained state for VariationalBayesianSemiSupervised
#[derive(Debug, Clone)]
pub struct VariationalBayesianTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels (with unlabeled samples filled in after fitting)
    pub y_train: Array1<i32>,
    /// Distinct class labels
    pub classes: Array1<i32>,
    /// Component means
    pub means: Array2<f64>,
    /// Component covariance matrices
    pub covariances: Vec<Array2<f64>>,
    /// Mixture weights
    pub mixing_weights: Array1<f64>,
    /// Per-sample component responsibilities
    pub responsibilities: Array2<f64>,
}

/// Bayesian Active Learning for Semi-Supervised Learning
///
/// This method uses Bayesian inference to select the most informative
/// unlabeled samples for labeling based on prediction uncertainty.
///
/// # Parameters
///
/// * `n_queries` - Number of samples to query
/// * `kernel` - Kernel function for GP
/// * `length_scale` - Length scale for RBF kernel
/// * `noise_level` - Noise level for GP
/// * `acquisition` - Acquisition function ('uncertainty', 'entropy')
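///
/// # Examples
///
/// A minimal usage sketch, mirroring the unit tests at the bottom of this
/// module (the data here is illustrative only):
///
/// ```rust,ignore
/// use scirs2_core::array;
/// use sklears_core::traits::{Fit, Predict};
/// use sklears_semi_supervised::BayesianActiveLearning;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let bal = BayesianActiveLearning::new()
///     .n_queries(2)
///     .acquisition("uncertainty".to_string())
///     .random_state(42);
/// let fitted = bal.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```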
#[derive(Debug, Clone)]
pub struct BayesianActiveLearning<S = Untrained> {
    state: S,
    n_queries: usize,
    kernel: String,
    length_scale: f64,
    noise_level: f64,
    acquisition: String,
    random_state: Option<u64>,
}

impl BayesianActiveLearning<Untrained> {
    /// Create a new BayesianActiveLearning instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_queries: 10,
            kernel: "rbf".to_string(),
            length_scale: 1.0,
            noise_level: 0.1,
            acquisition: "uncertainty".to_string(),
            random_state: None,
        }
    }

    /// Set the number of queries
    pub fn n_queries(mut self, n_queries: usize) -> Self {
        self.n_queries = n_queries;
        self
    }

    /// Set the kernel function
    pub fn kernel(mut self, kernel: String) -> Self {
        self.kernel = kernel;
        self
    }

    /// Set the length scale
    pub fn length_scale(mut self, length_scale: f64) -> Self {
        self.length_scale = length_scale;
        self
    }

    /// Set the noise level
    pub fn noise_level(mut self, noise_level: f64) -> Self {
        self.noise_level = noise_level;
        self
    }

    /// Set the acquisition function
    pub fn acquisition(mut self, acquisition: String) -> Self {
        self.acquisition = acquisition;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for BayesianActiveLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for BayesianActiveLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for BayesianActiveLearning<Untrained> {
    type Fitted = BayesianActiveLearning<BayesianActiveTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        // Identify labeled and unlabeled samples
        let mut labeled_indices = Vec::new();
        let mut unlabeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label == -1 {
                unlabeled_indices.push(i);
            } else {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        let classes: Vec<i32> = classes.into_iter().collect();

        // Compute uncertainties for all unlabeled samples
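        // Note: this is a cheap geometric proxy; a fully Bayesian treatment
        // would rank candidates by GP posterior variance or entropy (the
        // `acquisition` parameter, which this simplified version ignores).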
        let mut uncertainties = Vec::new();
        for &idx in &unlabeled_indices {
            // Simple uncertainty proxy: distance to the nearest labeled point
            let mut min_dist = f64::INFINITY;
            for &labeled_idx in &labeled_indices {
                let diff = &X.row(idx) - &X.row(labeled_idx);
                let dist = diff.mapv(|x| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                }
            }
            uncertainties.push((idx, min_dist));
        }

        // Sort by decreasing uncertainty and select the top queries
        uncertainties.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        let query_indices: Vec<usize> = uncertainties
            .iter()
            .take(self.n_queries.min(unlabeled_indices.len()))
            .map(|(idx, _)| *idx)
            .collect();

        Ok(BayesianActiveLearning {
            state: BayesianActiveTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                query_indices,
                uncertainties: uncertainties.iter().map(|(_, u)| *u).collect(),
            },
            n_queries: self.n_queries,
            kernel: self.kernel,
            length_scale: self.length_scale,
            noise_level: self.noise_level,
            acquisition: self.acquisition,
            random_state: self.random_state,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>> for BayesianActiveLearning<BayesianActiveTrained> {
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        // Simple nearest-neighbor prediction over the labeled training samples
        for i in 0..n_test {
            let mut min_dist = f64::INFINITY;
            let mut best_label = self.state.classes[0];

            for j in 0..self.state.X_train.nrows() {
                if self.state.y_train[j] != -1 {
                    let diff = &X.row(i) - &self.state.X_train.row(j);
                    let dist = diff.mapv(|x| x * x).sum().sqrt();

                    if dist < min_dist {
                        min_dist = dist;
                        best_label = self.state.y_train[j];
                    }
                }
            }

            predictions[i] = best_label;
        }

        Ok(predictions)
    }
}

/// Hierarchical Bayesian Semi-Supervised Learning
///
/// This method uses hierarchical Bayesian modeling to learn from both
/// labeled and unlabeled data with multiple levels of hierarchy.
///
/// # Parameters
///
/// * `n_levels` - Number of hierarchy levels
/// * `n_components` - Number of components per level
/// * `max_iter` - Maximum number of iterations
/// * `prior_strength` - Strength of hierarchical prior
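///
/// # Examples
///
/// A minimal usage sketch, mirroring the unit tests at the bottom of this
/// module (the data here is illustrative only):
///
/// ```rust,ignore
/// use scirs2_core::array;
/// use sklears_core::traits::{Fit, Predict};
/// use sklears_semi_supervised::HierarchicalBayesianSemiSupervised;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let hb = HierarchicalBayesianSemiSupervised::new()
///     .n_levels(2)
///     .n_components(2)
///     .max_iter(10)
///     .random_state(42);
/// let fitted = hb.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```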
#[derive(Debug, Clone)]
pub struct HierarchicalBayesianSemiSupervised<S = Untrained> {
    state: S,
    n_levels: usize,
    n_components: usize,
    max_iter: usize,
    prior_strength: f64,
    random_state: Option<u64>,
}

impl HierarchicalBayesianSemiSupervised<Untrained> {
    /// Create a new HierarchicalBayesianSemiSupervised instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_levels: 2,
            n_components: 2,
            max_iter: 100,
            prior_strength: 1.0,
            random_state: None,
        }
    }

    /// Set the number of levels
    pub fn n_levels(mut self, n_levels: usize) -> Self {
        self.n_levels = n_levels;
        self
    }

    /// Set the number of components
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the prior strength
    pub fn prior_strength(mut self, prior_strength: f64) -> Self {
        self.prior_strength = prior_strength;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for HierarchicalBayesianSemiSupervised<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for HierarchicalBayesianSemiSupervised<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>>
    for HierarchicalBayesianSemiSupervised<Untrained>
{
    type Fitted = HierarchicalBayesianSemiSupervised<HierarchicalBayesianTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();
        let (_n_samples, n_features) = X.dim();

        // Identify labeled and unlabeled samples
        let mut labeled_indices = Vec::new();
        let mut unlabeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label == -1 {
                unlabeled_indices.push(i);
            } else {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        let classes: Vec<i32> = classes.into_iter().collect();
        let n_classes = classes.len();

        // Initialize random number generator
        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            Random::seed(
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            )
        };

        // Initialize hierarchical parameters
        let mut level_means = Vec::new();
        for _ in 0..self.n_levels {
            let mut means = Array2::<f64>::zeros((self.n_components, n_features));
            for i in 0..self.n_components {
                for j in 0..n_features {
                    means[[i, j]] = rng.random_range(-1.0..1.0);
                }
            }
            level_means.push(means);
        }

        // Simple hierarchical optimization
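        // Each iteration applies an exponential moving average toward the
        // global data mean: m <- 0.9 * m + 0.1 * mean(X). This shrinkage is a
        // placeholder for a full hierarchical update and gradually pulls all
        // component means toward mean(X).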
        for _iter in 0..self.max_iter {
            // Update lower levels based on data
            // Simplified: shrink every component mean toward the data mean
            let mean = X.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
            #[allow(clippy::needless_range_loop)]
            for level_idx in 0..self.n_levels {
                for comp_idx in 0..self.n_components {
                    for feat_idx in 0..n_features {
                        level_means[level_idx][[comp_idx, feat_idx]] = 0.9
                            * level_means[level_idx][[comp_idx, feat_idx]]
                            + 0.1 * mean[feat_idx];
                    }
                }
            }
        }

        // Predict labels for unlabeled samples
        let mut final_labels = y.clone();
        for &idx in &unlabeled_indices {
            // Find nearest component in the lowest level
            let mut min_dist = f64::INFINITY;
            let mut best_component = 0;

            for comp_idx in 0..self.n_components {
                let diff = &X.row(idx) - &level_means[0].row(comp_idx);
                let dist = diff.mapv(|x| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_component = comp_idx;
                }
            }

            // Map component to class
            let predicted_class = classes[best_component % n_classes];
            final_labels[idx] = predicted_class;
        }

        Ok(HierarchicalBayesianSemiSupervised {
            state: HierarchicalBayesianTrained {
                X_train: X,
                y_train: final_labels,
                classes: Array1::from(classes),
                level_means,
            },
            n_levels: self.n_levels,
            n_components: self.n_components,
            max_iter: self.max_iter,
            prior_strength: self.prior_strength,
            random_state: self.random_state,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for HierarchicalBayesianSemiSupervised<HierarchicalBayesianTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Nearest-neighbor prediction over the (relabeled) training samples
            let mut min_dist = f64::INFINITY;
            let mut best_label = self.state.classes[0];

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x| x * x).sum().sqrt();

                if dist < min_dist {
                    min_dist = dist;
                    best_label = self.state.y_train[j];
                }
            }

            predictions[i] = best_label;
        }

        Ok(predictions)
    }
}

/// Trained state for BayesianActiveLearning
#[derive(Debug, Clone)]
pub struct BayesianActiveTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels (-1 for unlabeled samples)
    pub y_train: Array1<i32>,
    /// Distinct class labels
    pub classes: Array1<i32>,
    /// Indices of the unlabeled samples selected for labeling
    pub query_indices: Vec<usize>,
    /// Uncertainty scores for the unlabeled samples (descending)
    pub uncertainties: Vec<f64>,
}

/// Trained state for HierarchicalBayesianSemiSupervised
#[derive(Debug, Clone)]
pub struct HierarchicalBayesianTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels (with unlabeled samples filled in after fitting)
    pub y_train: Array1<i32>,
    /// Distinct class labels
    pub classes: Array1<i32>,
    /// Component means per hierarchy level
    pub level_means: Vec<Array2<f64>>,
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_gaussian_process_semi_supervised() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let gp = GaussianProcessSemiSupervised::new()
            .kernel("rbf".to_string())
            .length_scale(1.0)
            .noise_level(0.1)
            .random_state(42);

        let fitted = gp.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();
        let probas = fitted.predict_proba(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert_eq!(probas.dim(), (4, 2));
        assert!(predictions.iter().all(|&p| p >= 0 && p <= 1));

        // Check that probabilities sum to 1
        for i in 0..4 {
            let sum: f64 = probas.row(i).sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }

        // Check that labeled samples are predicted correctly
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_variational_bayesian_semi_supervised() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let vb = VariationalBayesianSemiSupervised::new()
            .n_components(2)
            .max_iter(10)
            .random_state(42);

        let fitted = vb.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p >= 0 && p <= 1));
    }

    #[test]
    fn test_gaussian_process_parameters() {
        let gp = GaussianProcessSemiSupervised::new()
            .kernel("linear".to_string())
            .length_scale(2.0)
            .noise_level(0.2)
            .alpha(1e-8)
            .n_restarts_optimizer(5);

        assert_eq!(gp.kernel, "linear");
        assert_eq!(gp.length_scale, 2.0);
        assert_eq!(gp.noise_level, 0.2);
        assert_eq!(gp.alpha, 1e-8);
        assert_eq!(gp.n_restarts_optimizer, 5);
    }

    #[test]
    fn test_variational_bayesian_parameters() {
        let vb = VariationalBayesianSemiSupervised::new()
            .n_components(4)
            .max_iter(200)
            .tol(1e-6)
            .reg_covar(1e-4)
            .alpha_prior(2.0);

        assert_eq!(vb.n_components, 4);
        assert_eq!(vb.max_iter, 200);
        assert_eq!(vb.tol, 1e-6);
        assert_eq!(vb.reg_covar, 1e-4);
        assert_eq!(vb.alpha_prior, 2.0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_kernel_matrix_computation() {
        let gp = GaussianProcessSemiSupervised::new()
            .kernel("rbf".to_string())
            .length_scale(1.0);
        let X = array![[1.0, 2.0], [3.0, 4.0]];

        let K = gp.compute_kernel_matrix(&X, &X).unwrap();

        assert_eq!(K.dim(), (2, 2));
        assert!((K[[0, 0]] - 1.0).abs() < 1e-10); // Self-kernel should be 1
        assert!((K[[1, 1]] - 1.0).abs() < 1e-10);
        assert!(K[[0, 1]] > 0.0 && K[[0, 1]] < 1.0); // Cross-kernel should be between 0 and 1
        assert!((K[[0, 1]] - K[[1, 0]]).abs() < 1e-10); // Should be symmetric
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_linear_kernel() {
        let gp = GaussianProcessSemiSupervised::new().kernel("linear".to_string());
        let X = array![[1.0, 2.0], [3.0, 4.0]];

        let K = gp.compute_kernel_matrix(&X, &X).unwrap();

        assert_eq!(K.dim(), (2, 2));
        assert!((K[[0, 0]] - 5.0).abs() < 1e-10); // 1*1 + 2*2 = 5
        assert!((K[[1, 1]] - 25.0).abs() < 1e-10); // 3*3 + 4*4 = 25
        assert!((K[[0, 1]] - 11.0).abs() < 1e-10); // 1*3 + 2*4 = 11
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_polynomial_kernel() {
        let gp = GaussianProcessSemiSupervised::new()
            .kernel("polynomial".to_string())
            .length_scale(1.0);
        let X = array![[1.0, 1.0]];

        let K = gp.compute_kernel_matrix(&X, &X).unwrap();

        assert_eq!(K.dim(), (1, 1));
        assert!((K[[0, 0]] - 9.0).abs() < 1e-10); // (1 + 1*1 + 1*1)^2 = 3^2 = 9
    }
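
    #[test]
    #[allow(non_snake_case)]
    fn test_unknown_kernel_error() {
        // Added check: exercises the explicit error path in
        // `compute_kernel_matrix` for an unrecognized kernel name.
        let gp = GaussianProcessSemiSupervised::new().kernel("unknown".to_string());
        let X = array![[1.0, 2.0], [3.0, 4.0]];

        let result = gp.compute_kernel_matrix(&X, &X);
        assert!(result.is_err());
    }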

    #[test]
    #[allow(non_snake_case)]
    fn test_empty_labeled_samples_error() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![-1, -1]; // No labeled samples

        let gp = GaussianProcessSemiSupervised::new();
        let result = gp.fit(&X.view(), &y.view());

        assert!(result.is_err());

        let vb = VariationalBayesianSemiSupervised::new();
        let result = vb.fit(&X.view(), &y.view());

        assert!(result.is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_single_labeled_sample() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0, -1]; // One labeled sample

        let gp = GaussianProcessSemiSupervised::new()
            .noise_level(0.1)
            .random_state(42);

        let fitted = gp.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 2);
        assert_eq!(predictions[0], 0); // Labeled sample should be correct
    }

    #[test]
    fn test_log_probability_computation() {
        let vb = VariationalBayesianSemiSupervised::new();
        let x = array![1.0, 2.0];
        let mean = array![1.0, 2.0];
        let mut covar = Array2::<f64>::zeros((2, 2));
        covar[[0, 0]] = 1.0;
        covar[[1, 1]] = 1.0;

        let log_prob = vb.compute_log_probability(&x.view(), &mean.view(), &covar);

        // For x = mean, the probability should be maximized
        assert!(log_prob.is_finite());
        assert!(log_prob < 0.0); // Log probability should be negative
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_bayesian_active_learning() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let bal = BayesianActiveLearning::new().n_queries(2).random_state(42);

        let fitted = bal.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p >= 0 && p <= 1));
        assert_eq!(fitted.state.query_indices.len(), 2);
    }

    #[test]
    fn test_bayesian_active_learning_parameters() {
        let bal = BayesianActiveLearning::new()
            .n_queries(5)
            .kernel("rbf".to_string())
            .length_scale(2.0)
            .noise_level(0.2)
            .acquisition("entropy".to_string());

        assert_eq!(bal.n_queries, 5);
        assert_eq!(bal.kernel, "rbf");
        assert_eq!(bal.length_scale, 2.0);
        assert_eq!(bal.noise_level, 0.2);
        assert_eq!(bal.acquisition, "entropy");
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_hierarchical_bayesian() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let hb = HierarchicalBayesianSemiSupervised::new()
            .n_levels(2)
            .n_components(2)
            .max_iter(10)
            .random_state(42);

        let fitted = hb.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p >= 0 && p <= 1));
        assert_eq!(fitted.state.level_means.len(), 2);
    }

    #[test]
    fn test_hierarchical_bayesian_parameters() {
        let hb = HierarchicalBayesianSemiSupervised::new()
            .n_levels(3)
            .n_components(4)
            .max_iter(200)
            .prior_strength(2.0);

        assert_eq!(hb.n_levels, 3);
        assert_eq!(hb.n_components, 4);
        assert_eq!(hb.max_iter, 200);
        assert_eq!(hb.prior_strength, 2.0);
    }
}