sklears_multioutput/utils.rs

//! Utility functions and shared types for multi-output learning
//!
//! This module contains common functionality used across multiple algorithms in the
//! multi-output learning suite, including mathematical operations, binary classifier
//! training, feature processing, and label reconstruction methods.

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    types::Float,
};

/// Methods for label reconstruction in compressed sensing approaches
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ReconstructionMethod {
    /// Linear reconstruction using least squares
    Linear,
    /// Iterative soft thresholding
    IterativeThresholding,
    /// Orthogonal matching pursuit
    OrthogonalMatchingPursuit,
}

/// Strategies for pruning label combinations in label powerset methods
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PruningStrategy {
    /// Remove rare combinations and map to default (e.g., all zeros)
    Default,
    /// Map rare combinations to most similar frequent combination
    Similarity,
}

/// Classification criteria for decision trees
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ClassificationCriterion {
    /// Gini impurity
    Gini,
    /// Information gain (entropy)
    Entropy,
}

/// Threshold strategies for binary relevance methods
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ThresholdStrategy {
    /// Fixed threshold for all labels
    Fixed,
    /// Per-label optimal thresholds
    PerLabel,
    /// Optimal thresholds based on validation data
    Optimal,
    /// F-score based thresholds
    FScore,
}

/// Calibration methods for probability calibration
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CalibrationMethod {
    /// Sigmoid calibration (Platt scaling)
    Sigmoid,
    /// Isotonic regression calibration
    Isotonic,
}

/// Simple linear classifier for binary classification
#[derive(Debug, Clone)]
pub struct SimpleLinearClassifier {
    /// Weight vector
    pub weights: Array1<Float>,
    /// Bias term
    pub bias: Float,
}

/// Simple binary classification model
#[derive(Debug, Clone)]
pub struct SimpleBinaryModel {
    /// Feature weights
    pub weights: Array1<Float>,
    /// Bias term
    pub bias: Float,
    /// Training accuracy
    pub accuracy: Float,
}

/// Bayesian binary classification model
#[derive(Debug, Clone)]
pub struct BayesianBinaryModel {
    /// Posterior mean of weights
    pub weight_mean: Array1<Float>,
    /// Posterior covariance of weights
    pub weight_cov: Array2<Float>,
    /// Bias parameter
    pub bias_mean: Float,
    /// Bias variance
    pub bias_var: Float,
    /// Noise precision
    pub noise_precision: Float,
}

/// Cost matrix for cost-sensitive learning
#[derive(Debug, Clone)]
pub struct CostMatrix {
    /// False positive costs for each label
    pub fp_costs: Vec<Float>,
    /// False negative costs for each label
    pub fn_costs: Vec<Float>,
}

impl CostMatrix {
    /// Create cost matrix from false positive and false negative costs
    pub fn from_fp_fn_costs(fp_costs: Vec<Float>, fn_costs: Vec<Float>) -> SklResult<Self> {
        if fp_costs.len() != fn_costs.len() {
            return Err(SklearsError::InvalidInput(
                "False positive and false negative cost vectors must have the same length"
                    .to_string(),
            ));
        }

        if fp_costs.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Cost vectors cannot be empty".to_string(),
            ));
        }

        // Check for non-negative costs
        for &cost in fp_costs.iter().chain(fn_costs.iter()) {
            if cost < 0.0 {
                return Err(SklearsError::InvalidInput(
                    "All costs must be non-negative".to_string(),
                ));
            }
        }

        Ok(Self { fp_costs, fn_costs })
    }

    /// Create balanced cost matrix (all costs = 1.0)
    pub fn balanced(n_labels: usize) -> Self {
        Self {
            fp_costs: vec![1.0; n_labels],
            fn_costs: vec![1.0; n_labels],
        }
    }

    /// Get optimal threshold for a given label based on cost ratio
    pub fn get_threshold(&self, label_idx: usize) -> Float {
        if label_idx >= self.fp_costs.len() {
            return 0.5; // Default threshold
        }

        let fp_cost = self.fp_costs[label_idx];
        let fn_cost = self.fn_costs[label_idx];

        // Optimal threshold = FP_cost / (FP_cost + FN_cost)
        if fp_cost + fn_cost > 0.0 {
            fp_cost / (fp_cost + fn_cost)
        } else {
            0.5
        }
    }
}
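
// A minimal sketch of the cost/threshold relationship (the cost values below are
// illustrative, not part of the original API surface): when false negatives on a
// label are twice as costly as false positives, the optimal threshold drops below
// 0.5, so that label is predicted positive more eagerly.
#[cfg(test)]
mod cost_matrix_example {
    use super::*;

    #[test]
    fn threshold_tracks_cost_ratio() {
        let cm = CostMatrix::from_fp_fn_costs(vec![1.0, 1.0], vec![2.0, 1.0]).unwrap();
        // Label 0: 1.0 / (1.0 + 2.0) = 1/3
        assert!((cm.get_threshold(0) - 1.0 / 3.0).abs() < 1e-6);
        // Label 1: balanced costs give the default 0.5
        assert!((cm.get_threshold(1) - 0.5).abs() < 1e-6);
        // Out-of-range indices fall back to 0.5
        assert!((cm.get_threshold(99) - 0.5).abs() < 1e-6);
    }
}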

/// Calculate Euclidean distance between two points
pub fn euclidean_distance(x1: &ArrayView1<Float>, x2: &ArrayView1<Float>) -> Float {
    x1.iter()
        .zip(x2.iter())
        .map(|(a, b)| (a - b).powi(2))
        .sum::<Float>()
        .sqrt()
}

/// Standardize features using provided means and standard deviations
pub fn standardize_features_simple(
    X: &ArrayView2<Float>,
    means: &Array1<Float>,
    stds: &Array1<Float>,
) -> Array2<Float> {
    let mut X_standardized = X.to_owned();

    for (mut col, (&mean, &std)) in X_standardized
        .axis_iter_mut(Axis(1))
        .zip(means.iter().zip(stds.iter()))
    {
        // Guard against zero standard deviation (constant features) to avoid NaN/inf
        if std > 1e-10 {
            col.mapv_inplace(|x| (x - mean) / std);
        } else {
            col.mapv_inplace(|x| x - mean);
        }
    }

    X_standardized
}

/// Train a simple binary classifier using correlation-based approach
pub fn train_binary_classifier(
    X: &ArrayView2<Float>,
    y: &Array1<i32>,
) -> SklResult<SimpleBinaryModel> {
    let (n_samples, n_features) = X.dim();

    if n_samples != y.len() {
        return Err(SklearsError::InvalidInput(
            "X and y must have the same number of samples".to_string(),
        ));
    }

    // Convert labels to Float for computation
    let y_float: Array1<Float> = y.mapv(|x| x as Float);

    // Calculate means
    let x_means = X.mean_axis(Axis(0)).unwrap();
    let y_mean = y_float.mean().unwrap();

    // Calculate weights using correlation
    let mut weights = Array1::<Float>::zeros(n_features);

    for (i, weight) in weights.iter_mut().enumerate() {
        let x_col = X.column(i);
        let x_mean = x_means[i];

        let numerator: Float = x_col
            .iter()
            .zip(y_float.iter())
            .map(|(&x, &y)| (x - x_mean) * (y - y_mean))
            .sum();

        let denominator: Float = x_col
            .iter()
            .map(|&x| (x - x_mean).powi(2))
            .sum::<Float>()
            .sqrt()
            * y_float
                .iter()
                .map(|&y| (y - y_mean).powi(2))
                .sum::<Float>()
                .sqrt();

        *weight = if denominator > 1e-10 {
            numerator / denominator
        } else {
            0.0
        };
    }

    // Calculate bias
    let bias = y_mean - weights.dot(&x_means);

    // Calculate training accuracy
    let mut correct = 0;
    for i in 0..n_samples {
        let prediction = if weights.dot(&X.row(i)) + bias > 0.0 {
            1
        } else {
            0
        };
        if prediction == y[i] {
            correct += 1;
        }
    }
    let accuracy = correct as Float / n_samples as Float;

    Ok(SimpleBinaryModel {
        weights,
        bias,
        accuracy,
    })
}
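
// A small usage sketch on a linearly separable toy problem (illustrative data):
// the single informative feature should receive a positive correlation weight,
// and the trained model should classify the training set perfectly.
#[cfg(test)]
mod binary_classifier_example {
    use super::*;

    #[test]
    fn separable_data_is_fit_well() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = Array1::from_vec(vec![0, 0, 1, 1]);
        let model = train_binary_classifier(&x.view(), &y).unwrap();
        assert!(model.weights[0] > 0.0);
        assert!((model.accuracy - 1.0).abs() < 1e-6);
    }
}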

/// Train a simple linear classifier using least squares
pub fn train_simple_linear_classifier(
    X: &ArrayView2<Float>,
    y: &Array1<Float>,
) -> SklResult<SimpleLinearClassifier> {
    let (n_samples, n_features) = X.dim();

    if n_samples != y.len() {
        return Err(SklearsError::InvalidInput(
            "X and y must have the same number of samples".to_string(),
        ));
    }

    // Add bias column to X
    let mut X_with_bias = Array2::ones((n_samples, n_features + 1));
    X_with_bias.slice_mut(s![.., ..n_features]).assign(X);

    // Solve normal equations: (X^T X) w = X^T y
    let xtx = X_with_bias.t().dot(&X_with_bias);
    let xty = X_with_bias.t().dot(y);

    let weights_with_bias = solve_linear_system(&xtx, &xty)?;

    let weights = weights_with_bias.slice(s![..n_features]).to_owned();
    let bias = weights_with_bias[n_features];

    Ok(SimpleLinearClassifier { weights, bias })
}
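
// A minimal sketch of the least-squares fit (illustrative data): an exact affine
// relation y = 2*x + 1 should be recovered, since the normal equations are well
// conditioned here.
#[cfg(test)]
mod linear_classifier_example {
    use super::*;

    #[test]
    fn recovers_affine_relation() {
        let x = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
        let y = Array1::from_vec(vec![1.0, 3.0, 5.0]); // y = 2*x + 1
        let clf = train_simple_linear_classifier(&x.view(), &y).unwrap();
        assert!((clf.weights[0] - 2.0).abs() < 1e-4);
        assert!((clf.bias - 1.0).abs() < 1e-4);
        let preds = predict_simple_linear(&x.view(), &clf);
        assert!((preds[2] - 5.0).abs() < 1e-4);
    }
}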

/// Predict using a simple linear classifier
pub fn predict_simple_linear(
    X: &ArrayView2<Float>,
    classifier: &SimpleLinearClassifier,
) -> Array1<Float> {
    X.dot(&classifier.weights) + classifier.bias
}

/// Solve linear system Ax = b using Gaussian elimination
pub fn solve_linear_system(A: &Array2<Float>, b: &Array1<Float>) -> SklResult<Array1<Float>> {
    let n = A.nrows();
    if A.ncols() != n || b.len() != n {
        return Err(SklearsError::InvalidInput(
            "Matrix must be square and vector must match matrix size".to_string(),
        ));
    }

    let mut aug = Array2::<Float>::zeros((n, n + 1));
    aug.slice_mut(s![.., ..n]).assign(A);
    aug.slice_mut(s![.., n]).assign(b);

    // Forward elimination
    for i in 0..n {
        // Find pivot
        let mut max_row = i;
        for k in (i + 1)..n {
            if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                max_row = k;
            }
        }

        // Swap rows
        if max_row != i {
            for j in 0..=n {
                let temp = aug[[i, j]];
                aug[[i, j]] = aug[[max_row, j]];
                aug[[max_row, j]] = temp;
            }
        }

        // Check for singular matrix
        if aug[[i, i]].abs() < 1e-10 {
            // Add small regularization
            aug[[i, i]] += 1e-8;
        }

        // Eliminate
        for k in (i + 1)..n {
            let factor = aug[[k, i]] / aug[[i, i]];
            for j in i..=n {
                aug[[k, j]] -= factor * aug[[i, j]];
            }
        }
    }

    // Back substitution
    let mut x = Array1::<Float>::zeros(n);
    for i in (0..n).rev() {
        x[i] = aug[[i, n]];
        for j in (i + 1)..n {
            x[i] -= aug[[i, j]] * x[j];
        }
        x[i] /= aug[[i, i]];
    }

    Ok(x)
}
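
// A quick check of the pivoting solver on a small system (illustrative values):
// [2 1; 1 3] x = [5; 10] has the unique solution x = [1, 3].
#[cfg(test)]
mod linear_system_example {
    use super::*;

    #[test]
    fn solves_small_system() {
        let a = Array2::from_shape_vec((2, 2), vec![2.0, 1.0, 1.0, 3.0]).unwrap();
        let b = Array1::from_vec(vec![5.0, 10.0]);
        let x = solve_linear_system(&a, &b).unwrap();
        assert!((x[0] - 1.0).abs() < 1e-4);
        assert!((x[1] - 3.0).abs() < 1e-4);
    }
}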

/// Generate random projection matrix for compressed sensing
pub fn generate_random_projection_matrix(
    n_compressed: usize,
    n_labels: usize,
    random_state: Option<u64>,
) -> Array2<Float> {
    // Use deterministic random generation based on seed
    let mut rng_state = random_state.unwrap_or(42);

    let mut matrix = Array2::<Float>::zeros((n_compressed, n_labels));

    for i in 0..n_compressed {
        for j in 0..n_labels {
            // Simple LCG for reproducible random numbers
            rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
            let random_val = (rng_state as Float) / (u64::MAX as Float);
            matrix[[i, j]] = (random_val - 0.5) * 2.0; // Range [-1, 1]
        }
    }

    // Normalize rows
    for mut row in matrix.rows_mut() {
        let norm = row.iter().map(|x| x * x).sum::<Float>().sqrt();
        if norm > 1e-10 {
            row /= norm;
        }
    }

    matrix
}
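
// A short sketch of the projection-matrix invariant the reconstruction methods
// rely on (seed and shape are arbitrary): every non-degenerate row has unit L2
// norm after normalization.
#[cfg(test)]
mod projection_matrix_example {
    use super::*;

    #[test]
    fn rows_are_unit_norm() {
        let p = generate_random_projection_matrix(3, 8, Some(7));
        for row in p.rows() {
            let norm = row.iter().map(|x| x * x).sum::<Float>().sqrt();
            assert!((norm - 1.0).abs() < 1e-6);
        }
    }
}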

/// Reconstruct labels using specified method
pub fn reconstruct_labels(
    compressed_labels: &Array1<Float>,
    projection_matrix: &Array2<Float>,
    method: ReconstructionMethod,
) -> SklResult<Array1<Float>> {
    match method {
        ReconstructionMethod::Linear => {
            // Pseudo-inverse reconstruction
            let pinv = compute_pseudoinverse(projection_matrix)?;
            Ok(pinv.dot(compressed_labels))
        }
        ReconstructionMethod::IterativeThresholding => {
            iterative_thresholding_reconstruction(compressed_labels, projection_matrix)
        }
        ReconstructionMethod::OrthogonalMatchingPursuit => {
            omp_reconstruction(compressed_labels, projection_matrix)
        }
    }
}

/// Compute pseudo-inverse of a matrix
fn compute_pseudoinverse(matrix: &Array2<Float>) -> SklResult<Array2<Float>> {
    let (m, n) = matrix.dim();

    if m >= n {
        // More rows than columns: (A^T A)^-1 A^T
        let ata = matrix.t().dot(matrix);
        let ata_inv = matrix_inverse(&ata)?;
        Ok(ata_inv.dot(&matrix.t()))
    } else {
        // More columns than rows: A^T (A A^T)^-1
        let aat = matrix.dot(&matrix.t());
        let aat_inv = matrix_inverse(&aat)?;
        Ok(matrix.t().dot(&aat_inv))
    }
}
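
// A property sketch for the tall (m >= n) branch above (illustrative full-rank
// matrix): the left pseudo-inverse satisfies pinv(A) · A ≈ I.
#[cfg(test)]
mod pseudoinverse_example {
    use super::*;

    #[test]
    fn left_inverse_property() {
        let a = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unwrap();
        let pinv = compute_pseudoinverse(&a).unwrap();
        let prod = pinv.dot(&a);
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((prod[[i, j]] - expected).abs() < 1e-4);
            }
        }
    }
}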

/// Compute matrix inverse using Gaussian elimination
fn matrix_inverse(matrix: &Array2<Float>) -> SklResult<Array2<Float>> {
    let n = matrix.nrows();
    if matrix.ncols() != n {
        return Err(SklearsError::InvalidInput(
            "Matrix must be square".to_string(),
        ));
    }

    let mut aug = Array2::<Float>::zeros((n, 2 * n));
    aug.slice_mut(s![.., ..n]).assign(matrix);

    // Set up identity on the right side
    for i in 0..n {
        aug[[i, n + i]] = 1.0;
    }

    // Gaussian elimination
    for i in 0..n {
        // Find pivot
        let mut max_row = i;
        for k in (i + 1)..n {
            if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                max_row = k;
            }
        }

        // Swap rows
        if max_row != i {
            for j in 0..(2 * n) {
                let temp = aug[[i, j]];
                aug[[i, j]] = aug[[max_row, j]];
                aug[[max_row, j]] = temp;
            }
        }

        // Check for singular matrix
        if aug[[i, i]].abs() < 1e-10 {
            return Err(SklearsError::InvalidInput("Matrix is singular".to_string()));
        }

        // Scale pivot row
        let pivot = aug[[i, i]];
        for j in 0..(2 * n) {
            aug[[i, j]] /= pivot;
        }

        // Eliminate column
        for k in 0..n {
            if k != i {
                let factor = aug[[k, i]];
                for j in 0..(2 * n) {
                    aug[[k, j]] -= factor * aug[[i, j]];
                }
            }
        }
    }

    Ok(aug.slice(s![.., n..]).to_owned())
}

/// Iterative thresholding reconstruction
pub fn iterative_thresholding_reconstruction(
    compressed_labels: &Array1<Float>,
    projection_matrix: &Array2<Float>,
) -> SklResult<Array1<Float>> {
    let n_labels = projection_matrix.ncols();
    let mut x = Array1::<Float>::zeros(n_labels);
    let step_size = 0.1;
    let threshold = 0.1;
    let max_iterations = 100;

    for _ in 0..max_iterations {
        // Gradient step
        let residual = projection_matrix.dot(&x) - compressed_labels;
        let gradient = projection_matrix.t().dot(&residual);
        x = &x - step_size * &gradient;

        // Soft thresholding
        x.mapv_inplace(|xi| {
            if xi > threshold {
                xi - threshold
            } else if xi < -threshold {
                xi + threshold
            } else {
                0.0
            }
        });
    }

    Ok(x)
}
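
// A worked sketch with an identity "projection" (illustrative; the step size and
// threshold are the hard-coded 0.1 values above, which correspond to the lasso
// objective min 0.5*||x - z||^2 + ||x||_1): for z = [3, 0] the iteration
// x <- soft(x - 0.1*(x - z), 0.1) converges to x* = [2, 0], showing the
// characteristic L1 shrinkage of the nonzero entry.
#[cfg(test)]
mod ista_example {
    use super::*;

    #[test]
    fn converges_to_shrunken_solution() {
        let p = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let z = Array1::from_vec(vec![3.0, 0.0]);
        let x = iterative_thresholding_reconstruction(&z, &p).unwrap();
        assert!((x[0] - 2.0).abs() < 1e-3);
        assert!(x[1].abs() < 1e-6);
    }
}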

/// Orthogonal Matching Pursuit reconstruction
pub fn omp_reconstruction(
    compressed_labels: &Array1<Float>,
    projection_matrix: &Array2<Float>,
) -> SklResult<Array1<Float>> {
    let n_labels = projection_matrix.ncols();
    let mut selected_indices = Vec::new();
    let mut residual = compressed_labels.clone();
    let max_iterations = std::cmp::min(10, n_labels); // Sparsity constraint

    for _ in 0..max_iterations {
        // Find column with maximum correlation to residual
        let mut max_corr = 0.0;
        let mut best_idx = 0;

        for j in 0..n_labels {
            if !selected_indices.contains(&j) {
                let column = projection_matrix.column(j);
                let corr = column.dot(&residual).abs();
                if corr > max_corr {
                    max_corr = corr;
                    best_idx = j;
                }
            }
        }

        if max_corr < 1e-6 {
            break;
        }

        selected_indices.push(best_idx);

        // Solve least squares on selected columns
        if let Ok(coeffs) =
            solve_least_squares_subset(compressed_labels, projection_matrix, &selected_indices)
        {
            // Update residual
            let mut reconstruction = Array1::<Float>::zeros(projection_matrix.nrows());
            for (i, &idx) in selected_indices.iter().enumerate() {
                let column = projection_matrix.column(idx);
                reconstruction = reconstruction + coeffs[i] * &column;
            }
            residual = compressed_labels - &reconstruction;
        }
    }

    // Construct final solution
    let mut x = Array1::<Float>::zeros(n_labels);
    if let Ok(coeffs) =
        solve_least_squares_subset(compressed_labels, projection_matrix, &selected_indices)
    {
        for (i, &idx) in selected_indices.iter().enumerate() {
            x[idx] = coeffs[i];
        }
    }

    Ok(x)
}
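
// A sketch of exact sparse recovery (illustrative dictionary with orthogonal
// columns): OMP selects the two active columns in order of correlation, solves
// for their coefficients exactly, and leaves the inactive label at zero.
#[cfg(test)]
mod omp_example {
    use super::*;

    #[test]
    fn recovers_sparse_support() {
        let p = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap();
        let z = Array1::from_vec(vec![0.5, 2.0]);
        let x = omp_reconstruction(&z, &p).unwrap();
        assert!((x[0] - 0.5).abs() < 1e-4);
        assert!((x[1] - 2.0).abs() < 1e-4);
        assert!(x[2].abs() < 1e-6);
    }
}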

/// Solve least squares problem for a subset of columns
pub fn solve_least_squares_subset(
    y: &Array1<Float>,
    A: &Array2<Float>,
    indices: &[usize],
) -> SklResult<Array1<Float>> {
    if indices.is_empty() {
        return Err(SklearsError::InvalidInput(
            "No indices provided".to_string(),
        ));
    }

    let n_rows = A.nrows();
    let n_selected = indices.len();
    let mut A_subset = Array2::<Float>::zeros((n_rows, n_selected));

    for (j, &idx) in indices.iter().enumerate() {
        A_subset.column_mut(j).assign(&A.column(idx));
    }

    // Solve normal equations
    let ata = A_subset.t().dot(&A_subset);
    let aty = A_subset.t().dot(y);

    solve_linear_system(&ata, &aty)
}

/// Train a weighted binary classifier
pub fn train_weighted_binary_classifier_simple(
    X: &ArrayView2<Float>,
    y: &Array1<i32>,
    sample_weights: &Array1<Float>,
) -> SklResult<SimpleBinaryModel> {
    let (n_samples, n_features) = X.dim();

    if n_samples != y.len() || n_samples != sample_weights.len() {
        return Err(SklearsError::InvalidInput(
            "X, y, and sample_weights must have the same number of samples".to_string(),
        ));
    }

    // Weighted means
    let total_weight = sample_weights.sum();
    let mut x_means = Array1::<Float>::zeros(n_features);
    let mut y_mean = 0.0;

    for i in 0..n_samples {
        let weight = sample_weights[i];
        y_mean += weight * y[i] as Float;
        for j in 0..n_features {
            x_means[j] += weight * X[[i, j]];
        }
    }

    x_means /= total_weight;
    y_mean /= total_weight;

    // Weighted covariance computation
    let mut weights = Array1::<Float>::zeros(n_features);

    for j in 0..n_features {
        let mut numerator = 0.0;
        let mut x_var = 0.0;
        let mut y_var = 0.0;

        for i in 0..n_samples {
            let weight = sample_weights[i];
            let x_diff = X[[i, j]] - x_means[j];
            let y_diff = y[i] as Float - y_mean;

            numerator += weight * x_diff * y_diff;
            x_var += weight * x_diff * x_diff;
            y_var += weight * y_diff * y_diff;
        }

        let denominator = (x_var * y_var).sqrt();
        weights[j] = if denominator > 1e-10 {
            numerator / denominator
        } else {
            0.0
        };
    }

    let bias = y_mean - weights.dot(&x_means);

    // Calculate weighted accuracy
    let mut correct_weight = 0.0;
    for i in 0..n_samples {
        let prediction = if weights.dot(&X.row(i)) + bias > 0.0 {
            1
        } else {
            0
        };
        if prediction == y[i] {
            correct_weight += sample_weights[i];
        }
    }
    let accuracy = correct_weight / total_weight;

    Ok(SimpleBinaryModel {
        weights,
        bias,
        accuracy,
    })
}

/// Predict binary probabilities using sigmoid function
pub fn predict_binary_probabilities(
    X: &ArrayView2<Float>,
    model: &SimpleBinaryModel,
) -> Array1<Float> {
    let raw_scores = X.dot(&model.weights) + model.bias;
    raw_scores.mapv(|x| 1.0 / (1.0 + (-x).exp()))
}

/// Compute cost-sensitive sample weights
pub fn compute_cost_sensitive_weights(y: &Array2<i32>, cost_matrix: &CostMatrix) -> Array1<Float> {
    let n_samples = y.nrows();
    let mut weights = Array1::ones(n_samples);

    for i in 0..n_samples {
        let mut sample_weight = 1.0;

        for (j, &label) in y.row(i).iter().enumerate() {
            if j < cost_matrix.fp_costs.len() {
                // Weight positives by their false-negative cost and negatives
                // by their false-positive cost
                if label == 1 {
                    sample_weight *= cost_matrix.fn_costs[j];
                } else {
                    sample_weight *= cost_matrix.fp_costs[j];
                }
            }
        }

        weights[i] = sample_weight;
    }

    // Normalize weights to sum to n_samples
    let total_weight = weights.sum();
    if total_weight > 0.0 {
        weights *= n_samples as Float / total_weight;
    }

    weights
}
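
// A worked sketch (illustrative costs): positives on label 0 carry a
// false-negative cost of 3, so the sample containing that positive gets three
// times the raw weight of the all-negative sample, before normalization to sum
// to n_samples.
#[cfg(test)]
mod cost_weights_example {
    use super::*;

    #[test]
    fn weights_reflect_costs() {
        let y = Array2::from_shape_vec((2, 2), vec![1, 0, 0, 0]).unwrap();
        let cm = CostMatrix::from_fp_fn_costs(vec![1.0, 1.0], vec![3.0, 1.0]).unwrap();
        let w = compute_cost_sensitive_weights(&y, &cm);
        // Raw weights [3, 1] normalized to sum to 2: [1.5, 0.5]
        assert!((w[0] - 1.5).abs() < 1e-6);
        assert!((w[1] - 0.5).abs() < 1e-6);
        assert!((w.sum() - 2.0).abs() < 1e-6);
    }
}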

/// Generate a standard-normal random number using the Box-Muller transform
pub fn random_normal() -> Float {
    use std::cell::RefCell;

    thread_local! {
        static SAVED: RefCell<Option<Float>> = const { RefCell::new(None) };
    }

    SAVED.with(|saved| {
        if let Some(value) = saved.borrow_mut().take() {
            return value;
        }

        // Generate two uniform random numbers; offset u1 into (0, 1] so ln(u1) stays finite
        let u1 = ((rand_u32() as Float) + 1.0) / ((u32::MAX as Float) + 1.0);
        let u2 = (rand_u32() as Float) / (u32::MAX as Float);

        // Box-Muller transform
        let mag = (-2.0 * u1.ln()).sqrt();
        let z0 = mag * (2.0 * std::f64::consts::PI * u2 as f64).cos() as Float;
        let z1 = mag * (2.0 * std::f64::consts::PI * u2 as f64).sin() as Float;

        *saved.borrow_mut() = Some(z1);
        z0
    })
}

/// Simple random number generator for reproducible results
fn rand_u32() -> u32 {
    use std::cell::RefCell;

    thread_local! {
        static STATE: RefCell<u32> = const { RefCell::new(42) };
    }

    STATE.with(|state| {
        let mut s = state.borrow_mut();
        *s = s.wrapping_mul(1664525).wrapping_add(1013904223);
        *s
    })
}
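
// A statistical smoke sketch (deterministic, since the generator starts from a
// fixed thread-local state and each test runs on its own thread): the sample
// mean of many Box-Muller draws should sit near 0 and the sample variance near
// 1. The tolerances are loose because the underlying LCG is only a rough
// uniform source.
#[cfg(test)]
mod random_normal_example {
    use super::*;

    #[test]
    fn samples_look_standard_normal() {
        let n = 10_000;
        let samples: Vec<Float> = (0..n).map(|_| random_normal()).collect();
        let mean = samples.iter().sum::<Float>() / n as Float;
        let var = samples.iter().map(|s| (s - mean) * (s - mean)).sum::<Float>() / n as Float;
        assert!(mean.abs() < 0.1);
        assert!((var - 1.0).abs() < 0.15);
    }
}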

/// Train a Bayesian binary classifier
pub fn train_bayesian_binary_classifier(
    X: &Array2<Float>,
    y: &Array1<i32>,
    alpha: Float,
) -> SklResult<BayesianBinaryModel> {
    let (n_samples, n_features) = X.dim();

    if n_samples != y.len() {
        return Err(SklearsError::InvalidInput(
            "X and y must have the same number of samples".to_string(),
        ));
    }

    // Convert labels to Float and to {-1, 1}
    let y_float: Array1<Float> = y.mapv(|x| if x == 1 { 1.0 } else { -1.0 });

    // Prior precision matrix (alpha * I)
    let prior_precision = Array2::<Float>::eye(n_features) * alpha;

    // Likelihood precision (simplified)
    let noise_precision = 1.0;

    // Posterior precision = prior_precision + X^T X * noise_precision
    let xtx = X.t().dot(X);
    let posterior_precision = &prior_precision + &xtx * noise_precision;

    // Posterior covariance = inv(posterior_precision)
    let weight_cov = matrix_inverse(&posterior_precision)?;

    // Posterior mean = posterior_cov * X^T * y * noise_precision
    let xty = X.t().dot(&y_float);
    let weight_mean = weight_cov.dot(&(&xty * noise_precision));

    // Bias parameters (simplified)
    let bias_mean = 0.0;
    let bias_var = 1.0 / alpha;

    Ok(BayesianBinaryModel {
        weight_mean,
        weight_cov,
        bias_mean,
        bias_var,
        noise_precision,
    })
}
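
// A usage sketch (illustrative data; alpha is the prior precision): a stronger
// prior (larger alpha) increases the posterior precision and therefore shrinks
// the posterior weight mean toward zero.
#[cfg(test)]
mod bayesian_example {
    use super::*;

    #[test]
    fn larger_alpha_shrinks_weights() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = Array1::from_vec(vec![0, 0, 1, 1]);
        let weak = train_bayesian_binary_classifier(&x, &y, 0.1).unwrap();
        let strong = train_bayesian_binary_classifier(&x, &y, 10.0).unwrap();
        assert!(weak.weight_mean[0].abs() > strong.weight_mean[0].abs());
    }
}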

/// Predict with Bayesian binary classifier (mean prediction)
pub fn predict_bayesian_binary(
    X: &ArrayView2<Float>,
    model: &BayesianBinaryModel,
) -> Array1<Float> {
    let raw_scores = X.dot(&model.weight_mean) + model.bias_mean;
    raw_scores.mapv(|x| 1.0 / (1.0 + (-x).exp()))
}

/// Predict with uncertainty quantification
pub fn predict_bayesian_uncertainty(
    X: &ArrayView2<Float>,
    model: &BayesianBinaryModel,
) -> SklResult<(Array1<Float>, Array1<Float>)> {
    let n_samples = X.nrows();
    let mut means = Array1::<Float>::zeros(n_samples);
    let mut variances = Array1::<Float>::zeros(n_samples);

    for i in 0..n_samples {
        let x = X.row(i);

        // Mean prediction
        let mean_score = x.dot(&model.weight_mean) + model.bias_mean;
        means[i] = 1.0 / (1.0 + (-mean_score).exp());

        // Variance calculation
        let score_var =
            x.dot(&model.weight_cov.dot(&x)) + model.bias_var + 1.0 / model.noise_precision;
        variances[i] = score_var;
    }

    Ok((means, variances))
}

/// Predict mean of Bayesian model
pub fn predict_bayesian_mean(X: &ArrayView2<Float>, model: &BayesianBinaryModel) -> Array1<Float> {
    X.dot(&model.weight_mean) + model.bias_mean
}