sklears_gaussian_process/
classification.rs

//! Gaussian Process Classification Models
//!
//! This module provides different implementations of Gaussian Process classifiers:
//! - `GaussianProcessClassifier`: Binary classification using Laplace approximation
//! - `MultiClassGaussianProcessClassifier`: Multi-class classification using One-vs-Rest strategy
//! - `ExpectationPropagationGaussianProcessClassifier`: EP-based classification
//!
//! All classifiers support different kernel functions and provide probability estimates.

use std::collections::HashSet;
use std::f64::consts::PI;

// Use ndarray imports (following existing pattern in the codebase)
// SciRS2 Policy - Use scirs2-autograd for ndarray types and operations
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
};

use crate::kernels::Kernel;
use crate::utils;

/// Binary Gaussian Process Classifier using the Laplace approximation
///
/// A Gaussian process prior is placed over a latent function; the non-Gaussian
/// posterior induced by the logistic likelihood is approximated by a Gaussian
/// centered at its mode (the Laplace approximation), yielding class
/// probability estimates.
///
/// # Examples
///
/// ```
/// use sklears_gaussian_process::{GaussianProcessClassifier, kernels::RBF};
/// use sklears_core::traits::{Fit, Predict};
/// // SciRS2 Policy - Use scirs2-autograd for ndarray types and operations
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0], [2.0], [3.0], [4.0]];
/// let y = array![0, 0, 1, 1];
///
/// let kernel = RBF::new(1.0);
/// let gpc = GaussianProcessClassifier::new().kernel(Box::new(kernel));
/// let fitted = gpc.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct GaussianProcessClassifier<S = Untrained> {
    state: S,
    kernel: Option<Box<dyn Kernel>>,
    optimizer: Option<String>,
    n_restarts_optimizer: usize,
    max_iter_predict: usize,
    warm_start: bool,
    copy_x_train: bool,
    random_state: Option<u64>,
    config: GpcConfig,
}

/// Trained state for Gaussian Process Classifier
#[derive(Debug, Clone)]
pub struct GpcTrained {
    /// Training inputs (`None` when `copy_x_train` is false)
    pub X_train: Option<Array2<f64>>,
    /// Training labels
    pub y_train: Array1<i32>,
    /// Unique class labels, sorted ascending
    pub classes: Array1<i32>,
    /// Approximate posterior probabilities at the training points
    pub pi: Array1<f64>,
    /// sqrt(W), where W is the diagonal Hessian of the negative log likelihood
    pub W_sr: Array1<f64>,
    /// Cholesky factor of K + diag(W)
    pub L: Array2<f64>,
    /// Kernel matrix over the training inputs
    pub K: Array2<f64>,
    /// Latent function values at the approximate posterior mode
    pub f: Array1<f64>,
    /// Kernel function
    pub kernel: Box<dyn Kernel>,
    /// Log marginal likelihood of the fitted model
    pub log_marginal_likelihood_value: f64,
}

impl GaussianProcessClassifier<Untrained> {
    /// Create a new GaussianProcessClassifier instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernel: None,
            optimizer: Some("fmin_l_bfgs_b".to_string()),
            n_restarts_optimizer: 0,
            max_iter_predict: 100,
            warm_start: false,
            copy_x_train: true,
            random_state: None,
            config: GpcConfig::default(),
        }
    }

    /// Set the kernel function
    pub fn kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
        self.kernel = Some(kernel);
        self
    }

    /// Set the optimizer
    pub fn optimizer(mut self, optimizer: Option<String>) -> Self {
        self.optimizer = optimizer;
        self
    }

    /// Set the number of optimizer restarts
    pub fn n_restarts_optimizer(mut self, n_restarts: usize) -> Self {
        self.n_restarts_optimizer = n_restarts;
        self
    }

    /// Set the maximum number of iterations for prediction
    pub fn max_iter_predict(mut self, max_iter: usize) -> Self {
        self.max_iter_predict = max_iter;
        self
    }

    /// Set whether to use warm start
    pub fn warm_start(mut self, warm_start: bool) -> Self {
        self.warm_start = warm_start;
        self
    }

    /// Set whether to copy X during training
    pub fn copy_x_train(mut self, copy_x_train: bool) -> Self {
        self.copy_x_train = copy_x_train;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }
}

/// Configuration for Gaussian Process Classifier
#[derive(Debug, Clone)]
pub struct GpcConfig {
    /// Name of the kernel function
    pub kernel_name: String,
    /// Optimizer for hyperparameter tuning (`None` disables optimization)
    pub optimizer: Option<String>,
    /// Number of optimizer restarts
    pub n_restarts_optimizer: usize,
    /// Maximum number of Newton iterations used during fitting
    pub max_iter_predict: usize,
    /// Whether to reuse the previous solution as initialization
    pub warm_start: bool,
    /// Whether to store a copy of the training inputs
    pub copy_x_train: bool,
    /// Seed for reproducible randomness
    pub random_state: Option<u64>,
}

impl Default for GpcConfig {
    fn default() -> Self {
        Self {
            kernel_name: "RBF".to_string(),
            optimizer: Some("fmin_l_bfgs_b".to_string()),
            n_restarts_optimizer: 0,
            max_iter_predict: 100,
            warm_start: false,
            copy_x_train: true,
            random_state: None,
        }
    }
}

impl Estimator for GaussianProcessClassifier<Untrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Estimator for GaussianProcessClassifier<GpcTrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for GaussianProcessClassifier<Untrained> {
    type Fitted = GaussianProcessClassifier<GpcTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<i32>) -> SklResult<Self::Fitted> {
        if X.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let kernel = self
            .kernel
            .as_ref()
            .ok_or_else(|| SklearsError::InvalidInput("Kernel must be specified".to_string()))?
            .clone();

        // Get unique classes
        let mut classes_set: HashSet<i32> = HashSet::new();
        for &label in y.iter() {
            classes_set.insert(label);
        }

        if classes_set.len() != 2 {
            return Err(SklearsError::InvalidInput(
                "Binary classification requires exactly 2 classes".to_string(),
            ));
        }

        let mut classes: Vec<i32> = classes_set.into_iter().collect();
        classes.sort();
        let classes = Array1::from(classes);

        // Convert labels to {0, 1}, matching the logistic-likelihood
        // convention used in `laplace_approximation` below
        let y_binary = y.mapv(|label| if label == classes[0] { 0.0 } else { 1.0 });

        // Compute kernel matrix
        let X_owned = X.to_owned();
        let K = kernel.compute_kernel_matrix(&X_owned, None)?;

        // Laplace approximation
        let (f, pi, W_sr, L, log_marginal_likelihood_value) =
            laplace_approximation(&K, &y_binary, self.max_iter_predict)?;

        let X_train = if self.copy_x_train {
            Some(X.to_owned())
        } else {
            None
        };

        Ok(GaussianProcessClassifier {
            state: GpcTrained {
                X_train,
                y_train: y.to_owned(),
                classes,
                pi,
                W_sr,
                L,
                K,
                f,
                kernel,
                log_marginal_likelihood_value,
            },
            kernel: None,
            optimizer: self.optimizer,
            n_restarts_optimizer: self.n_restarts_optimizer,
            max_iter_predict: self.max_iter_predict,
            warm_start: self.warm_start,
            copy_x_train: self.copy_x_train,
            random_state: self.random_state,
            config: self.config,
        })
    }
}

impl Predict<ArrayView2<'_, f64>, Array1<i32>> for GaussianProcessClassifier<GpcTrained> {
    fn predict(&self, X: &ArrayView2<f64>) -> SklResult<Array1<i32>> {
        let probabilities = self.predict_proba(X)?;
        let predictions: Vec<i32> = probabilities
            .axis_iter(Axis(0))
            .map(|row| {
                let max_idx = row
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap();
                self.state.classes[max_idx]
            })
            .collect();
        Ok(Array1::from(predictions))
    }
}

impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for GaussianProcessClassifier<GpcTrained> {
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<f64>) -> SklResult<Array2<f64>> {
        let X_train =
            self.state.X_train.as_ref().ok_or_else(|| {
                SklearsError::InvalidInput("Training data not available".to_string())
            })?;

        // Compute kernel between test and training points
        let X_test_owned = X.to_owned();
        let K_star = self
            .state
            .kernel
            .compute_kernel_matrix(X_train, Some(&X_test_owned))?;

        // Predict latent function values
        let f_star =
            predict_latent_function(&K_star, &self.state.f, &self.state.W_sr, &self.state.L)?;

        // Convert to probabilities using sigmoid
        let mut probabilities = Array2::<f64>::zeros((X.nrows(), 2));
        for (i, &f_val) in f_star.iter().enumerate() {
            let prob_positive = sigmoid(f_val);
            probabilities[[i, 0]] = 1.0 - prob_positive; // Probability of class 0
            probabilities[[i, 1]] = prob_positive; // Probability of class 1
        }

        Ok(probabilities)
    }
}

impl GaussianProcessClassifier<GpcTrained> {
    /// Get the log marginal likelihood
    pub fn log_marginal_likelihood(&self) -> f64 {
        self.state.log_marginal_likelihood_value
    }

    /// Get the unique classes
    pub fn classes(&self) -> &Array1<i32> {
        &self.state.classes
    }
}

/// Sigmoid function
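///
/// Computes the logistic link `1 / (1 + e^{-x})`, mapping latent function
/// values to probabilities in (0, 1).
///
/// ```
/// // Illustrative doctest: assumes `sigmoid` is reachable at this path
/// // (adjust the import if the `classification` module is re-exported
/// // differently).
/// use sklears_gaussian_process::classification::sigmoid;
/// assert!((sigmoid(0.0) - 0.5).abs() < 1e-12);
/// assert!(sigmoid(10.0) > 0.999 && sigmoid(-10.0) < 0.001);
/// ```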
pub fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Derivative of sigmoid function
pub fn sigmoid_derivative(x: f64) -> f64 {
    let s = sigmoid(x);
    s * (1.0 - s)
}

/// Laplace approximation for binary classification
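///
/// Finds an approximate mode of the latent posterior by iterating simplified
/// Newton-style steps (see the loop body), then forms a Gaussian
/// approximation around it. Expects labels `y` in {0, 1} and returns
/// `(f, pi, W_sr, L, log_marginal_likelihood)`: the latent values, `sigmoid(f)`,
/// the square root of the Hessian diagonal `W`, the Cholesky factor of
/// `K + diag(W)`, and the approximate log marginal likelihood.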
#[allow(non_snake_case)]
fn laplace_approximation(
    K: &Array2<f64>,
    y: &Array1<f64>,
    max_iter: usize,
) -> SklResult<(Array1<f64>, Array1<f64>, Array1<f64>, Array2<f64>, f64)> {
    let n = K.nrows();
    let mut f = Array1::<f64>::zeros(n);
    let tol = 1e-6;

    for _iter in 0..max_iter {
        // Compute probabilities and the likelihood Hessian diagonal
        let pi = f.mapv(sigmoid);
        let W = f.mapv(sigmoid_derivative);

        // Gradient of the negative log likelihood (y in {0, 1})
        let grad = &pi - y;

        // Simplified Newton-style step: factor K + diag(W) and solve
        // against the gradient via the Cholesky factor
        let mut K_W = K.clone();
        for i in 0..n {
            K_W[[i, i]] += W[i];
        }

        let L = utils::robust_cholesky(&K_W)?;
        let delta_f = utils::triangular_solve(&L, &grad)?;

        let f_new = &f - &delta_f;

        // Check convergence
        let diff = (&f_new - &f).mapv(|x| x.abs()).sum();
        f = f_new;

        if diff < tol {
            break;
        }
    }

    // Final computation at the returned point
    let pi = f.mapv(sigmoid);
    let W = f.mapv(sigmoid_derivative);
    let W_sr = W.mapv(|w| w.sqrt());

    // Compute final Cholesky decomposition of K + diag(W)
    let mut K_W = K.clone();
    for i in 0..n {
        K_W[[i, i]] += W[i];
    }
    let L = utils::robust_cholesky(&K_W)?;

    // Approximate log marginal likelihood: Bernoulli log likelihood minus half
    // the log determinant (the prior quadratic term -0.5 * f^T K^{-1} f of the
    // full Laplace expression is omitted in this simplified version)
    let log_marginal_likelihood = {
        let log_det = 2.0 * L.diag().mapv(|x| x.ln()).sum();
        let log_lik: f64 = f
            .iter()
            .zip(y.iter())
            .map(|(&f_i, &y_i)| y_i * f_i - (1.0 + f_i.exp()).ln())
            .sum();
        log_lik - 0.5 * log_det
    };

    Ok((f, pi, W_sr, L, log_marginal_likelihood))
}

/// Predict latent function values
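///
/// A simplified sketch: returns `K_star^T · f_train` directly. The full
/// Laplace predictive mean would be `k(x_*)^T · (y - pi)` (Rasmussen &
/// Williams, Eq. 3.21); `_W_sr` and `_L` are accepted so an exact version
/// can be slotted in without changing the call site.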
fn predict_latent_function(
    K_star: &Array2<f64>,
    f_train: &Array1<f64>,
    _W_sr: &Array1<f64>,
    _L: &Array2<f64>,
) -> SklResult<Array1<f64>> {
    // K_star has shape (n_train, n_test), so transpose before applying it
    // to the training-point latent values
    let f_star_values = K_star.t().dot(f_train);
    Ok(f_star_values)
}

impl Default for GaussianProcessClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

/// Multi-class Gaussian Process Classifier using One-vs-Rest strategy
///
/// This implementation extends binary Gaussian Process Classification to handle
/// multi-class problems by training one binary classifier per class against all others.
/// Each binary classifier uses the Laplace approximation for inference.
///
/// # Examples
///
/// ```
/// use sklears_gaussian_process::{MultiClassGaussianProcessClassifier, kernels::RBF};
/// use sklears_core::traits::{Fit, Predict};
/// // SciRS2 Policy - Use scirs2-autograd for ndarray types and operations
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
/// let y = array![0, 0, 1, 1, 2, 2];
///
/// let kernel = RBF::new(1.0);
/// let mc_gpc = MultiClassGaussianProcessClassifier::new()
///     .kernel(Box::new(kernel));
/// let fitted = mc_gpc.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct MultiClassGaussianProcessClassifier<S = Untrained> {
    state: S,
    kernel: Option<Box<dyn Kernel>>,
    optimizer: Option<String>,
    n_restarts_optimizer: usize,
    max_iter_predict: usize,
    warm_start: bool,
    copy_x_train: bool,
    random_state: Option<u64>,
    config: GpcConfig,
}

/// Trained state for Multi-class Gaussian Process Classifier
#[derive(Debug, Clone)]
pub struct McGpcTrained {
    /// Training inputs (`None` when `copy_x_train` is false)
    pub X_train: Option<Array2<f64>>,
    /// Training labels
    pub y_train: Array1<i32>,
    /// Unique class labels, sorted ascending
    pub classes: Array1<i32>,
    /// One-vs-rest binary classifiers, one per class
    pub binary_classifiers: Vec<GaussianProcessClassifier<GpcTrained>>,
    /// Number of classes
    pub n_classes: usize,
}

impl MultiClassGaussianProcessClassifier<Untrained> {
    /// Create a new MultiClassGaussianProcessClassifier instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernel: None,
            optimizer: Some("fmin_l_bfgs_b".to_string()),
            n_restarts_optimizer: 0,
            max_iter_predict: 100,
            warm_start: false,
            copy_x_train: true,
            random_state: None,
            config: GpcConfig::default(),
        }
    }

    /// Set the kernel function
    pub fn kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
        self.kernel = Some(kernel);
        self
    }

    /// Set the optimizer
    pub fn optimizer(mut self, optimizer: Option<String>) -> Self {
        self.optimizer = optimizer;
        self
    }

    /// Set the number of optimizer restarts
    pub fn n_restarts_optimizer(mut self, n_restarts: usize) -> Self {
        self.n_restarts_optimizer = n_restarts;
        self
    }

    /// Set the maximum number of iterations for prediction
    pub fn max_iter_predict(mut self, max_iter: usize) -> Self {
        self.max_iter_predict = max_iter;
        self
    }

    /// Set whether to use warm start
    pub fn warm_start(mut self, warm_start: bool) -> Self {
        self.warm_start = warm_start;
        self
    }

    /// Set whether to copy X during training
    pub fn copy_x_train(mut self, copy_x_train: bool) -> Self {
        self.copy_x_train = copy_x_train;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }
}

impl Estimator for MultiClassGaussianProcessClassifier<Untrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Estimator for MultiClassGaussianProcessClassifier<McGpcTrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>>
    for MultiClassGaussianProcessClassifier<Untrained>
{
    type Fitted = MultiClassGaussianProcessClassifier<McGpcTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<i32>) -> SklResult<Self::Fitted> {
        let kernel = self
            .kernel
            .as_ref()
            .ok_or_else(|| SklearsError::InvalidInput("Kernel must be specified".to_string()))?
            .clone();

        // Get unique classes
        let mut classes_set: HashSet<i32> = HashSet::new();
        for &label in y.iter() {
            classes_set.insert(label);
        }

        if classes_set.len() < 2 {
            return Err(SklearsError::InvalidInput(
                "Multi-class classification requires at least 2 classes".to_string(),
            ));
        }

        let mut classes: Vec<i32> = classes_set.into_iter().collect();
        classes.sort();
        let classes = Array1::from(classes);
        let n_classes = classes.len();

        // Handle binary case by delegating to binary classifier
        if n_classes == 2 {
            let binary_gpc = GaussianProcessClassifier::new()
                .kernel(kernel)
                .optimizer(self.optimizer.clone())
                .n_restarts_optimizer(self.n_restarts_optimizer)
                .max_iter_predict(self.max_iter_predict)
                .warm_start(self.warm_start)
                .copy_x_train(self.copy_x_train)
                .random_state(self.random_state);

            let fitted_binary = binary_gpc.fit(X, y)?;

            let X_train = if self.copy_x_train {
                Some(X.to_owned())
            } else {
                None
            };

            return Ok(MultiClassGaussianProcessClassifier {
                state: McGpcTrained {
                    X_train,
                    y_train: y.to_owned(),
                    classes,
                    binary_classifiers: vec![fitted_binary],
                    n_classes,
                },
                kernel: None,
                optimizer: self.optimizer.clone(),
                n_restarts_optimizer: self.n_restarts_optimizer,
                max_iter_predict: self.max_iter_predict,
                warm_start: self.warm_start,
                copy_x_train: self.copy_x_train,
                random_state: self.random_state,
                config: self.config.clone(),
            });
        }

        // Multi-class case: One-vs-Rest strategy
        let mut binary_classifiers = Vec::with_capacity(n_classes);

        for (class_idx, &current_class) in classes.iter().enumerate() {
            // Create one-vs-rest labels: 1 for the current class, 0 for the rest
            let y_binary: Array1<i32> = y.mapv(|label| if label == current_class { 1 } else { 0 });

            // Create and train binary classifier for this class
            let binary_gpc = GaussianProcessClassifier::new()
                .kernel(kernel.clone())
                .optimizer(self.optimizer.clone())
                .n_restarts_optimizer(self.n_restarts_optimizer)
                .max_iter_predict(self.max_iter_predict)
                .warm_start(self.warm_start)
                .copy_x_train(self.copy_x_train)
                .random_state(self.random_state.map(|s| s + class_idx as u64));

            let fitted_binary = binary_gpc.fit(X, &y_binary.view())?;
            binary_classifiers.push(fitted_binary);
        }

        let X_train = if self.copy_x_train {
            Some(X.to_owned())
        } else {
            None
        };

        Ok(MultiClassGaussianProcessClassifier {
            state: McGpcTrained {
                X_train,
                y_train: y.to_owned(),
                classes,
                binary_classifiers,
                n_classes,
            },
            kernel: None,
            optimizer: self.optimizer,
            n_restarts_optimizer: self.n_restarts_optimizer,
            max_iter_predict: self.max_iter_predict,
            warm_start: self.warm_start,
            copy_x_train: self.copy_x_train,
            random_state: self.random_state,
            config: self.config.clone(),
        })
    }
}

impl Predict<ArrayView2<'_, f64>, Array1<i32>>
    for MultiClassGaussianProcessClassifier<McGpcTrained>
{
    fn predict(&self, X: &ArrayView2<f64>) -> SklResult<Array1<i32>> {
        let probabilities = self.predict_proba(X)?;
        let predictions: Vec<i32> = probabilities
            .axis_iter(Axis(0))
            .map(|row| {
                let max_idx = row
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap();
                self.state.classes[max_idx]
            })
            .collect();
        Ok(Array1::from(predictions))
    }
}

impl PredictProba<ArrayView2<'_, f64>, Array2<f64>>
    for MultiClassGaussianProcessClassifier<McGpcTrained>
{
    fn predict_proba(&self, X: &ArrayView2<f64>) -> SklResult<Array2<f64>> {
        let n_samples = X.nrows();
        let n_classes = self.state.n_classes;

        // Handle binary case
        if n_classes == 2 {
            return self.state.binary_classifiers[0].predict_proba(X);
        }

        // Multi-class case: collect probabilities from all binary classifiers
        let mut all_probabilities = Array2::<f64>::zeros((n_samples, n_classes));

        for (class_idx, binary_classifier) in self.state.binary_classifiers.iter().enumerate() {
            let binary_proba = binary_classifier.predict_proba(X)?;
            // Take the positive class probability (column 1)
            for i in 0..n_samples {
                all_probabilities[[i, class_idx]] = binary_proba[[i, 1]];
            }
        }

        // Normalize probabilities to sum to 1 for each sample
        for i in 0..n_samples {
            let row_sum: f64 = all_probabilities.row(i).sum();
            if row_sum > 1e-12 {
                for j in 0..n_classes {
                    all_probabilities[[i, j]] /= row_sum;
                }
            } else {
                // If all probabilities are near zero, assign uniform probabilities
                for j in 0..n_classes {
                    all_probabilities[[i, j]] = 1.0 / n_classes as f64;
                }
            }
        }

        Ok(all_probabilities)
    }
}

impl MultiClassGaussianProcessClassifier<McGpcTrained> {
    /// Get the unique classes
    pub fn classes(&self) -> &Array1<i32> {
        &self.state.classes
    }

    /// Get the number of classes
    pub fn n_classes(&self) -> usize {
        self.state.n_classes
    }

    /// Get the binary classifiers
    pub fn binary_classifiers(&self) -> &[GaussianProcessClassifier<GpcTrained>] {
        &self.state.binary_classifiers
    }

    /// Get the log marginal likelihood for a specific class
    pub fn log_marginal_likelihood(&self, class_idx: usize) -> Option<f64> {
        if class_idx < self.state.binary_classifiers.len() {
            Some(self.state.binary_classifiers[class_idx].log_marginal_likelihood())
        } else {
            None
        }
    }

    /// Get the average log marginal likelihood across all classes
    pub fn average_log_marginal_likelihood(&self) -> f64 {
        let sum: f64 = self
            .state
            .binary_classifiers
            .iter()
            .map(|classifier| classifier.log_marginal_likelihood())
            .sum();
        sum / self.state.binary_classifiers.len() as f64
    }
}

impl Default for MultiClassGaussianProcessClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

/// Expectation Propagation Gaussian Process Classifier
///
/// This implementation uses Expectation Propagation (EP) for approximate inference
/// in Gaussian Process Classification. EP provides a better approximation than the
/// Laplace method by iteratively refining local approximations to the likelihood.
///
/// EP maintains a multivariate Gaussian approximation to the posterior by minimizing
/// the KL divergence between the true posterior and the approximation at each data point.
///
/// # Examples
///
/// ```
/// use sklears_gaussian_process::{ExpectationPropagationGaussianProcessClassifier, kernels::RBF};
/// use sklears_core::traits::{Fit, Predict};
/// // SciRS2 Policy - Use scirs2-autograd for ndarray types and operations
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0], [2.0], [3.0], [4.0]];
/// let y = array![0, 0, 1, 1];
///
/// let kernel = RBF::new(1.0);
/// let ep_gpc = ExpectationPropagationGaussianProcessClassifier::new()
///     .kernel(Box::new(kernel));
/// let fitted = ep_gpc.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct ExpectationPropagationGaussianProcessClassifier<S = Untrained> {
    state: S,
    kernel: Option<Box<dyn Kernel>>,
    max_iter: usize,
    tol: f64,
    damping: f64,
    min_variance: f64,
    verbose: bool,
    random_state: Option<u64>,
    config: GpcConfig,
}

/// Trained state for Expectation Propagation Gaussian Process Classifier
#[derive(Debug, Clone)]
pub struct EpGpcTrained {
    /// Training inputs
    pub X_train: Option<Array2<f64>>,
    /// Training labels
    pub y_train: Array1<i32>,
    /// Unique class labels, sorted ascending
    pub classes: Array1<i32>,
    /// Posterior mean
    pub mu: Array1<f64>,
    /// Posterior covariance
    pub Sigma: Array2<f64>,
    /// Site precisions
    pub tau: Array1<f64>,
    /// Site means scaled by site precisions
    pub nu: Array1<f64>,
    /// Kernel function
    pub kernel: Box<dyn Kernel>,
    /// Log marginal likelihood estimate from EP
    pub log_marginal_likelihood_value: f64,
    /// Number of EP iterations performed
    pub n_iterations: usize,
}

impl ExpectationPropagationGaussianProcessClassifier<Untrained> {
    /// Create a new ExpectationPropagationGaussianProcessClassifier instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernel: None,
            max_iter: 100,
            tol: 1e-4,
            damping: 0.5,
            min_variance: 1e-10,
            verbose: false,
            random_state: None,
            config: GpcConfig::default(),
        }
    }

    /// Set the kernel function
    pub fn kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
        self.kernel = Some(kernel);
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set damping factor for updates (clamped to [0, 1])
    pub fn damping(mut self, damping: f64) -> Self {
        self.damping = damping.clamp(0.0, 1.0);
        self
    }

    /// Set minimum variance threshold
    pub fn min_variance(mut self, min_variance: f64) -> Self {
        self.min_variance = min_variance;
        self
    }

    /// Set verbosity
    pub fn verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }
}

impl Estimator for ExpectationPropagationGaussianProcessClassifier<Untrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Estimator for ExpectationPropagationGaussianProcessClassifier<EpGpcTrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>>
    for ExpectationPropagationGaussianProcessClassifier<Untrained>
{
    type Fitted = ExpectationPropagationGaussianProcessClassifier<EpGpcTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<i32>) -> SklResult<Self::Fitted> {
        if X.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let kernel = self
            .kernel
            .as_ref()
            .ok_or_else(|| SklearsError::InvalidInput("Kernel must be specified".to_string()))?
            .clone();

        // Get unique classes
        let mut classes_set: HashSet<i32> = HashSet::new();
        for &label in y.iter() {
            classes_set.insert(label);
        }

        if classes_set.len() != 2 {
            return Err(SklearsError::InvalidInput(
                "Binary classification requires exactly 2 classes".to_string(),
            ));
        }

        let mut classes: Vec<i32> = classes_set.into_iter().collect();
        classes.sort();
        let classes = Array1::from(classes);

        // Convert labels to {-1, 1} (the probit-likelihood convention used by EP)
        let y_binary = y.mapv(|label| if label == classes[0] { -1.0 } else { 1.0 });

        // Compute kernel matrix
        let X_owned = X.to_owned();
        let K = kernel.compute_kernel_matrix(&X_owned, None)?;

        // Run Expectation Propagation
        let (mu, Sigma, tau, nu, log_marginal_likelihood_value, n_iterations) =
            expectation_propagation(
                &K,
                &y_binary,
                self.max_iter,
                self.tol,
                self.damping,
                self.min_variance,
                self.verbose,
            )?;

        let X_train = Some(X.to_owned());

        Ok(ExpectationPropagationGaussianProcessClassifier {
            state: EpGpcTrained {
                X_train,
                y_train: y.to_owned(),
                classes,
                mu,
                Sigma,
                tau,
                nu,
                kernel,
                log_marginal_likelihood_value,
                n_iterations,
            },
            kernel: None,
            max_iter: self.max_iter,
            tol: self.tol,
            damping: self.damping,
            min_variance: self.min_variance,
            verbose: self.verbose,
            random_state: self.random_state,
            config: self.config.clone(),
        })
    }
}

impl Predict<ArrayView2<'_, f64>, Array1<i32>>
    for ExpectationPropagationGaussianProcessClassifier<EpGpcTrained>
{
    fn predict(&self, X: &ArrayView2<f64>) -> SklResult<Array1<i32>> {
        let probabilities = self.predict_proba(X)?;
        let predictions: Vec<i32> = probabilities
            .axis_iter(Axis(0))
            .map(|row| {
                let max_idx = row
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap();
                self.state.classes[max_idx]
            })
            .collect();
        Ok(Array1::from(predictions))
    }
}

impl PredictProba<ArrayView2<'_, f64>, Array2<f64>>
    for ExpectationPropagationGaussianProcessClassifier<EpGpcTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<f64>) -> SklResult<Array2<f64>> {
        let X_train =
            self.state.X_train.as_ref().ok_or_else(|| {
                SklearsError::InvalidInput("Training data not available".to_string())
            })?;

        // Compute kernel between test and training points
        let X_test_owned = X.to_owned();
        let K_star = self
            .state
            .kernel
            .compute_kernel_matrix(X_train, Some(&X_test_owned))?;

        // Predict using EP posterior
        let (f_star_mean, f_star_var) = ep_predict(&K_star, &self.state.mu, &self.state.Sigma)?;

        // Convert to probabilities using probit approximation
        let mut probabilities = Array2::<f64>::zeros((X.nrows(), 2));
        for (i, (&mean, &var)) in f_star_mean.iter().zip(f_star_var.iter()).enumerate() {
            let std_dev = var.sqrt().max(1e-10);
            let z = mean / std_dev;
            let prob_positive = normal_cdf(z);
            probabilities[[i, 0]] = 1.0 - prob_positive; // Probability of class 0
            probabilities[[i, 1]] = prob_positive; // Probability of class 1
        }

        Ok(probabilities)
    }
}

impl ExpectationPropagationGaussianProcessClassifier<EpGpcTrained> {
    /// Get the log marginal likelihood
    pub fn log_marginal_likelihood(&self) -> f64 {
        self.state.log_marginal_likelihood_value
    }

    /// Get the unique classes
    pub fn classes(&self) -> &Array1<i32> {
        &self.state.classes
    }

    /// Get the posterior mean
    pub fn posterior_mean(&self) -> &Array1<f64> {
        &self.state.mu
    }

    /// Get the posterior covariance
    pub fn posterior_covariance(&self) -> &Array2<f64> {
        &self.state.Sigma
    }

    /// Get the number of EP iterations used
    pub fn n_iterations(&self) -> usize {
        self.state.n_iterations
    }
}

impl Default for ExpectationPropagationGaussianProcessClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

// Helper functions for Expectation Propagation

/// Run Expectation Propagation algorithm
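///
/// Sweeps over the sites, updating each site's natural parameters with
/// damping, until the largest parameter change in a sweep falls below `tol`
/// or `max_iter` sweeps have run. Returns
/// `(mu, Sigma, tau, nu, log_marginal_likelihood, n_iterations)`: the
/// posterior mean and covariance, the site precisions, the precision-scaled
/// site means, the EP estimate of the log marginal likelihood, and the
/// number of sweeps performed.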
#[allow(non_snake_case)]
fn expectation_propagation(
    K: &Array2<f64>,
    y: &Array1<f64>,
    max_iter: usize,
    tol: f64,
    damping: f64,
    min_variance: f64,
    verbose: bool,
) -> SklResult<(
    Array1<f64>,
    Array2<f64>,
    Array1<f64>,
    Array1<f64>,
    f64,
    usize,
)> {
    let n = K.nrows();

    // Initialize site parameters
    let mut tau = Array1::<f64>::zeros(n); // Site precisions
    let mut nu = Array1::<f64>::zeros(n); // Site means * precisions

    // Initialize posterior
    let mut mu = Array1::<f64>::zeros(n); // Posterior mean
    let mut Sigma = K.clone(); // Posterior covariance

    let mut converged = false;
    let mut iteration = 0;

    for iter in 0..max_iter {
        iteration = iter + 1;
        let mut max_change: f64 = 0.0;

        // EP updates for each site
        for i in 0..n {
            // Remove site i from cavity distribution
            let tau_cavity = 1.0 / Sigma[[i, i]] - tau[i];
            let mu_cavity = if tau_cavity > 1e-12 {
                mu[i] / (tau_cavity * Sigma[[i, i]]) - nu[i] / tau_cavity
            } else {
                0.0
            };

            // Compute marginal moments
            let sigma_cavity = if tau_cavity > 1e-12 {
                1.0 / tau_cavity
            } else {
                1e6
            };
            let (z0, z1, z2) = marginal_moments(y[i], mu_cavity, sigma_cavity);

            if z0 > 1e-12 {
                // Update site parameters
                let delta_tau = z1 / sigma_cavity - z2 / (sigma_cavity * sigma_cavity) - tau_cavity;
                let delta_nu = z1 / sigma_cavity - mu_cavity * tau_cavity;

                // Apply damping
                let tau_new = tau[i] + damping * delta_tau;
                let nu_new = nu[i] + damping * delta_nu;

                // Track convergence
                let change = (tau_new - tau[i]).abs() + (nu_new - nu[i]).abs();
                max_change = max_change.max(change);

                // Update site parameters
                tau[i] = tau_new.max(min_variance);
                nu[i] = nu_new;

                // Update posterior
                let tau_diff = tau[i] - tau_cavity;
                let nu_diff = nu[i] - mu_cavity * tau_cavity;

                if tau_diff.abs() > 1e-12 {
                    // Rank-1 update of Sigma
                    let si = Sigma.column(i).to_owned();
                    let denom = 1.0 + tau_diff * Sigma[[i, i]];

                    if denom.abs() > 1e-12 {
                        for j in 0..n {
                            for k in 0..n {
                                Sigma[[j, k]] -= tau_diff * si[j] * si[k] / denom;
                            }
                        }

                        // Update mean
                        mu = &mu + (nu_diff / denom) * &si;
                    }
                }
            }
        }

        if verbose && iter % 10 == 0 {
            println!("EP iteration {}: max change = {:.6}", iter, max_change);
        }

        // Check convergence
        if max_change < tol {
            converged = true;
            if verbose {
                println!("EP converged at iteration {}", iter);
            }
            break;
        }
    }

    if !converged && verbose {
        println!("EP did not converge after {} iterations", max_iter);
    }

    // Compute log marginal likelihood
    let log_marginal_likelihood = compute_ep_log_marginal_likelihood(K, &tau, &nu, &mu, &Sigma, y)?;

    Ok((mu, Sigma, tau, nu, log_marginal_likelihood, iteration))
}

/// Compute marginal moments for probit likelihood
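///
/// For the probit likelihood `p(y | f) = Phi(y * f)` with a Gaussian cavity
/// `N(mu, sigma2)`, the normalizer is `z0 = Phi(y * mu / sqrt(sigma2))`;
/// `z1` is the mean of the tilted distribution and `z2` a second-moment
/// estimate, floored at a small minimum variance for numerical safety.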
fn marginal_moments(y: f64, mu: f64, sigma2: f64) -> (f64, f64, f64) {
    let sigma = sigma2.sqrt();
    let z = y * mu / sigma;

    // Compute PDF and CDF of standard normal
    let pdf = (-0.5 * z * z).exp() / (2.0 * PI).sqrt();
    let cdf = normal_cdf(z);

    // Avoid numerical issues
    let z0 = cdf.max(1e-12);
    let ratio = if z0 > 1e-12 { pdf / z0 } else { 0.0 };

    let z1 = mu + y * sigma * ratio;
    let z2 = mu * mu + sigma2 * (1.0 - z * ratio - ratio * ratio);

    (z0, z1, z2.max(min_variance_global()))
}

/// Standard normal CDF approximation
fn normal_cdf(x: f64) -> f64 {
    0.5 * (1.0 + erf(x / 2.0_f64.sqrt()))
}

/// Error function approximation
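/// (Abramowitz & Stegun formula 7.1.26; maximum absolute error about 1.5e-7.)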
fn erf(x: f64) -> f64 {
    // Abramowitz and Stegun approximation
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;
    let p = 0.3275911;

    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

    sign * y
}

/// Global minimum variance constant
fn min_variance_global() -> f64 {
    1e-10
}

/// Predict with EP posterior
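///
/// A simplified sketch: the predictive prior variance `k(x_*, x_*)` is
/// assumed to be 1 for every test point (true for a normalized kernel such
/// as the RBF), so the predictive variance is
/// `1 - diag(K_star^T · Sigma^{-1} · K_star)`, clamped to be non-negative.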
#[allow(non_snake_case)]
fn ep_predict(
    K_star: &Array2<f64>,
    mu: &Array1<f64>,
    Sigma: &Array2<f64>,
) -> SklResult<(Array1<f64>, Array1<f64>)> {
    // Predictive mean: K_star^T * Sigma^{-1} * mu
    let Sigma_inv = utils::matrix_inverse(Sigma)?;
    let mean = K_star.t().dot(&Sigma_inv.dot(mu));

    // Predictive variance: K_star_star - K_star^T * Sigma^{-1} * K_star
    let temp = K_star.t().dot(&Sigma_inv);
    let var_reduction = temp.dot(K_star);

    // Diagonal of var_reduction gives the variance reduction per test point;
    // the prior diagonal K_star_star is assumed to be 1, and the result is
    // clamped at zero so downstream square roots stay well defined
    let variance = Array1::from_iter(
        (0..K_star.ncols()).map(|i| (1.0 - var_reduction[[i, i]]).max(0.0)),
    );

    Ok((mean, variance))
}

/// Compute log marginal likelihood for EP
#[allow(non_snake_case)]
fn compute_ep_log_marginal_likelihood(
    K: &Array2<f64>,
    tau: &Array1<f64>,
    nu: &Array1<f64>,
    mu: &Array1<f64>,
    Sigma: &Array2<f64>,
    _y: &Array1<f64>,
) -> SklResult<f64> {
    let n = K.nrows();

    // Log determinant of posterior covariance
    let L_Sigma = utils::robust_cholesky(Sigma)?;
    let log_det_Sigma = 2.0 * L_Sigma.diag().mapv(|x| x.ln()).sum();

    // Log determinant of prior covariance
    let L_K = utils::robust_cholesky(K)?;
    let log_det_K = 2.0 * L_K.diag().mapv(|x| x.ln()).sum();

    // Quadratic terms
    let quad_prior = 0.5 * mu.dot(&utils::triangular_solve(&L_K, mu)?);
    let quad_posterior = 0.5 * mu.dot(&utils::triangular_solve(&L_Sigma, mu)?);

    // Site contributions
    let mut site_contrib = 0.0;
    for i in 0..n {
        if tau[i] > 1e-12 {
            let mu_i = nu[i] / tau[i];
            site_contrib += 0.5 * (tau[i].ln() - (2.0 * PI).ln()) - 0.5 * tau[i] * mu_i * mu_i;
        }
    }

    // Approximate marginal likelihood
    let log_ml =
        -0.5 * log_det_Sigma + 0.5 * log_det_K + site_contrib + quad_prior - quad_posterior;

    Ok(log_ml)
}
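
#[cfg(test)]
mod tests {
    // A minimal sanity-check sketch for the numerical helpers in this module;
    // the tolerances are illustrative and these tests do not exercise the
    // full classifiers end to end.
    use super::*;

    #[test]
    fn sigmoid_basic_properties() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-12);
        // sigmoid(x) + sigmoid(-x) = 1
        assert!((sigmoid(3.0) + sigmoid(-3.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn sigmoid_derivative_matches_finite_difference() {
        let x = 0.7;
        let h = 1e-6;
        let fd = (sigmoid(x + h) - sigmoid(x - h)) / (2.0 * h);
        assert!((sigmoid_derivative(x) - fd).abs() < 1e-6);
    }

    #[test]
    fn normal_cdf_matches_tabulated_values() {
        // Phi(0) = 0.5 exactly; Phi(1.96) ~ 0.9750 (the erf approximation is
        // accurate to roughly 1.5e-7)
        assert!((normal_cdf(0.0) - 0.5).abs() < 1e-6);
        assert!((normal_cdf(1.96) - 0.975).abs() < 1e-3);
    }

    #[test]
    fn laplace_approximation_smoke_test() {
        // Identity kernel matrix, labels in {0, 1}
        let K = Array2::<f64>::eye(3);
        let y = Array1::from(vec![0.0, 1.0, 1.0]);
        let (f, pi, _w_sr, _l, lml) = laplace_approximation(&K, &y, 50).unwrap();
        assert_eq!(f.len(), 3);
        assert!(pi.iter().all(|&p| (0.0..=1.0).contains(&p)));
        assert!(lml.is_finite());
    }
}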