use std::collections::HashSet;
use std::f64::consts::PI;

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
};

use crate::kernels::Kernel;
use crate::utils;

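/// Binary Gaussian process classifier.
///
/// The non-Gaussian posterior over the latent function is handled with a
/// Laplace approximation (see `laplace_approximation` below), following the
/// classic treatment in Rasmussen & Williams, *Gaussian Processes for
/// Machine Learning*, ch. 3.
///
/// A minimal usage sketch; `RbfKernel` is a hypothetical stand-in for
/// whatever kernel types `crate::kernels` actually provides:
///
/// ```ignore
/// let gpc = GaussianProcessClassifier::new()
///     .kernel(Box::new(RbfKernel::new(1.0))) // hypothetical constructor
///     .max_iter_predict(100);
/// let model = gpc.fit(&x.view(), &y.view())?; // y: Array1<i32> with two labels
/// let proba = model.predict_proba(&x_test.view())?;
/// ```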
#[derive(Debug, Clone)]
pub struct GaussianProcessClassifier<S = Untrained> {
    state: S,
    kernel: Option<Box<dyn Kernel>>,
    optimizer: Option<String>,
    n_restarts_optimizer: usize,
    max_iter_predict: usize,
    warm_start: bool,
    copy_x_train: bool,
    random_state: Option<u64>,
    config: GpcConfig,
}

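/// State produced by fitting: the Laplace approximation of the latent
/// posterior (mode `f`, fitted training probabilities `pi`, square root of
/// the likelihood Hessian diagonal `W_sr`, Cholesky factor `L` of the
/// regularised kernel matrix) plus the kernel and, optionally, the training
/// inputs needed at prediction time.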
#[derive(Debug, Clone)]
pub struct GpcTrained {
    pub X_train: Option<Array2<f64>>,
    pub y_train: Array1<i32>,
    pub classes: Array1<i32>,
    pub pi: Array1<f64>,
    pub W_sr: Array1<f64>,
    pub L: Array2<f64>,
    pub K: Array2<f64>,
    pub f: Array1<f64>,
    pub kernel: Box<dyn Kernel>,
    pub log_marginal_likelihood_value: f64,
}

impl GaussianProcessClassifier<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernel: None,
            optimizer: Some("fmin_l_bfgs_b".to_string()),
            n_restarts_optimizer: 0,
            max_iter_predict: 100,
            warm_start: false,
            copy_x_train: true,
            random_state: None,
            config: GpcConfig::default(),
        }
    }

    pub fn kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
        self.kernel = Some(kernel);
        self
    }

    pub fn optimizer(mut self, optimizer: Option<String>) -> Self {
        self.optimizer = optimizer;
        self
    }

    pub fn n_restarts_optimizer(mut self, n_restarts: usize) -> Self {
        self.n_restarts_optimizer = n_restarts;
        self
    }

    pub fn max_iter_predict(mut self, max_iter: usize) -> Self {
        self.max_iter_predict = max_iter;
        self
    }

    pub fn warm_start(mut self, warm_start: bool) -> Self {
        self.warm_start = warm_start;
        self
    }

    pub fn copy_x_train(mut self, copy_x_train: bool) -> Self {
        self.copy_x_train = copy_x_train;
        self
    }

    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }
}

#[derive(Debug, Clone)]
pub struct GpcConfig {
    pub kernel_name: String,
    pub optimizer: Option<String>,
    pub n_restarts_optimizer: usize,
    pub max_iter_predict: usize,
    pub warm_start: bool,
    pub copy_x_train: bool,
    pub random_state: Option<u64>,
}

impl Default for GpcConfig {
    fn default() -> Self {
        Self {
            kernel_name: "RBF".to_string(),
            optimizer: Some("fmin_l_bfgs_b".to_string()),
            n_restarts_optimizer: 0,
            max_iter_predict: 100,
            warm_start: false,
            copy_x_train: true,
            random_state: None,
        }
    }
}

impl Estimator for GaussianProcessClassifier<Untrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Estimator for GaussianProcessClassifier<GpcTrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>> for GaussianProcessClassifier<Untrained> {
    type Fitted = GaussianProcessClassifier<GpcTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<i32>) -> SklResult<Self::Fitted> {
        if X.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let kernel = self
            .kernel
            .as_ref()
            .ok_or_else(|| SklearsError::InvalidInput("Kernel must be specified".to_string()))?
            .clone();

        let mut classes_set: HashSet<i32> = HashSet::new();
        for &label in y.iter() {
            classes_set.insert(label);
        }

        if classes_set.len() != 2 {
            return Err(SklearsError::InvalidInput(
                "Binary classification requires exactly 2 classes".to_string(),
            ));
        }

        let mut classes: Vec<i32> = classes_set.into_iter().collect();
        classes.sort();
        let classes = Array1::from(classes);

        // Encode labels as 0/1 so that `pi - y` in `laplace_approximation` is
        // the gradient of the Bernoulli negative log-likelihood and
        // `y*f - ln(1 + e^f)` is the matching log-likelihood term.
        let y_binary = y.mapv(|label| if label == classes[0] { 0.0 } else { 1.0 });

        let X_owned = X.to_owned();
        let K = kernel.compute_kernel_matrix(&X_owned, None)?;

        let (f, pi, W_sr, L, log_marginal_likelihood_value) =
            laplace_approximation(&K, &y_binary, self.max_iter_predict)?;

        let X_train = if self.copy_x_train {
            Some(X.to_owned())
        } else {
            None
        };

        Ok(GaussianProcessClassifier {
            state: GpcTrained {
                X_train,
                y_train: y.to_owned(),
                classes,
                pi,
                W_sr,
                L,
                K,
                f,
                kernel,
                log_marginal_likelihood_value,
            },
            kernel: None,
            optimizer: self.optimizer,
            n_restarts_optimizer: self.n_restarts_optimizer,
            max_iter_predict: self.max_iter_predict,
            warm_start: self.warm_start,
            copy_x_train: self.copy_x_train,
            random_state: self.random_state,
            config: self.config,
        })
    }
}

impl Predict<ArrayView2<'_, f64>, Array1<i32>> for GaussianProcessClassifier<GpcTrained> {
    fn predict(&self, X: &ArrayView2<f64>) -> SklResult<Array1<i32>> {
        let probabilities = self.predict_proba(X)?;
        let predictions: Vec<i32> = probabilities
            .axis_iter(Axis(0))
            .map(|row| {
                let max_idx = row
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap();
                self.state.classes[max_idx]
            })
            .collect();
        Ok(Array1::from(predictions))
    }
}

impl PredictProba<ArrayView2<'_, f64>, Array2<f64>> for GaussianProcessClassifier<GpcTrained> {
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<f64>) -> SklResult<Array2<f64>> {
        let X_train = self.state.X_train.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("Training data not available".to_string())
        })?;

        let X_test_owned = X.to_owned();
        let K_star = self
            .state
            .kernel
            .compute_kernel_matrix(X_train, Some(&X_test_owned))?;

        let f_star =
            predict_latent_function(&K_star, &self.state.f, &self.state.W_sr, &self.state.L)?;

        let mut probabilities = Array2::<f64>::zeros((X.nrows(), 2));
        for (i, &f_val) in f_star.iter().enumerate() {
            let prob_positive = sigmoid(f_val);
            probabilities[[i, 0]] = 1.0 - prob_positive;
            probabilities[[i, 1]] = prob_positive;
        }

        Ok(probabilities)
    }
}

impl GaussianProcessClassifier<GpcTrained> {
    pub fn log_marginal_likelihood(&self) -> f64 {
        self.state.log_marginal_likelihood_value
    }

    pub fn classes(&self) -> &Array1<i32> {
        &self.state.classes
    }
}

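/// Logistic sigmoid `sigma(x) = 1 / (1 + exp(-x))`, mapping a latent function
/// value to a class-1 probability; `sigmoid_derivative` below gives
/// `sigma(x) * (1 - sigma(x))`, the per-point likelihood Hessian used as `W`.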
pub fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

pub fn sigmoid_derivative(x: f64) -> f64 {
    let s = sigmoid(x);
    s * (1.0 - s)
}

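/// Newton-style mode finding for the Laplace approximation (Rasmussen &
/// Williams, ch. 3): the latent posterior p(f | X, y) ∝ p(y | f) N(f; 0, K)
/// is approximated by a Gaussian centred at its mode. Each sweep builds the
/// gradient `pi - y` (labels in {0, 1}) and Hessian diagonal
/// `W = sigma(f)(1 - sigma(f))`, then takes a Newton-like step through the
/// Cholesky factor of `K + diag(W)`. This is a simplified variant of the
/// textbook recipe, which instead factors `B = I + W^{1/2} K W^{1/2}`; the
/// returned log marginal likelihood is the corresponding approximation.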
#[allow(non_snake_case)]
fn laplace_approximation(
    K: &Array2<f64>,
    y: &Array1<f64>,
    max_iter: usize,
) -> SklResult<(Array1<f64>, Array1<f64>, Array1<f64>, Array2<f64>, f64)> {
    let n = K.nrows();
    let mut f = Array1::<f64>::zeros(n);
    let tol = 1e-6;

    for _iter in 0..max_iter {
        let pi = f.mapv(sigmoid);
        let W = f.mapv(sigmoid_derivative);

        let grad = &pi - y;

        let mut K_W = K.clone();
        for i in 0..n {
            K_W[[i, i]] += W[i];
        }

        let L = utils::robust_cholesky(&K_W)?;
        let delta_f = utils::triangular_solve(&L, &grad)?;

        let f_new = &f - &delta_f;

        let diff = (&f_new - &f).mapv(|x| x.abs()).sum();
        f = f_new;

        if diff < tol {
            break;
        }
    }

    let pi = f.mapv(sigmoid);
    let W = f.mapv(sigmoid_derivative);
    let W_sr = W.mapv(|w| w.sqrt());

    let mut K_W = K.clone();
    for i in 0..n {
        K_W[[i, i]] += W[i];
    }
    let L = utils::robust_cholesky(&K_W)?;

    let log_marginal_likelihood = {
        let log_det = 2.0 * L.diag().mapv(|x| x.ln()).sum();
        let quadratic: f64 = f
            .iter()
            .zip(y.iter())
            .map(|(&f_i, &y_i)| y_i * f_i - (1.0 + f_i.exp()).ln())
            .sum();
        quadratic - 0.5 * log_det
    };

    Ok((f, pi, W_sr, L, log_marginal_likelihood))
}

fn predict_latent_function(
    K_star: &Array2<f64>,
    f_train: &Array1<f64>,
    _W_sr: &Array1<f64>,
    _L: &Array2<f64>,
) -> SklResult<Array1<f64>> {
    // Simplified predictive mean: project the posterior mode onto the test
    // points. `K_star` is assumed to be (n_train, n_test), the orientation
    // `ep_predict` below also relies on, so it is transposed before the
    // product. A full Laplace prediction would use k_*^T (y - pi) instead.
    let f_star_values = K_star.t().dot(f_train);
    Ok(f_star_values)
}

impl Default for GaussianProcessClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

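/// One-vs-rest multi-class wrapper around the binary
/// [`GaussianProcessClassifier`]: one binary classifier is fitted per class
/// (that class against all others), and the per-class probabilities are
/// renormalised to sum to one at prediction time. The two-class case
/// delegates to a single binary classifier.
///
/// A minimal sketch, again with a hypothetical `RbfKernel`:
///
/// ```ignore
/// let clf = MultiClassGaussianProcessClassifier::new()
///     .kernel(Box::new(RbfKernel::new(1.0)))
///     .random_state(Some(42)); // each one-vs-rest fit gets seed 42 + class index
/// let model = clf.fit(&x.view(), &y.view())?;
/// let labels = model.predict(&x_test.view())?;
/// ```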
#[derive(Debug, Clone)]
pub struct MultiClassGaussianProcessClassifier<S = Untrained> {
    state: S,
    kernel: Option<Box<dyn Kernel>>,
    optimizer: Option<String>,
    n_restarts_optimizer: usize,
    max_iter_predict: usize,
    warm_start: bool,
    copy_x_train: bool,
    random_state: Option<u64>,
    config: GpcConfig,
}

#[derive(Debug, Clone)]
pub struct McGpcTrained {
    pub X_train: Option<Array2<f64>>,
    pub y_train: Array1<i32>,
    pub classes: Array1<i32>,
    pub binary_classifiers: Vec<GaussianProcessClassifier<GpcTrained>>,
    pub n_classes: usize,
}

impl MultiClassGaussianProcessClassifier<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernel: None,
            optimizer: Some("fmin_l_bfgs_b".to_string()),
            n_restarts_optimizer: 0,
            max_iter_predict: 100,
            warm_start: false,
            copy_x_train: true,
            random_state: None,
            config: GpcConfig::default(),
        }
    }

    pub fn kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
        self.kernel = Some(kernel);
        self
    }

    pub fn optimizer(mut self, optimizer: Option<String>) -> Self {
        self.optimizer = optimizer;
        self
    }

    pub fn n_restarts_optimizer(mut self, n_restarts: usize) -> Self {
        self.n_restarts_optimizer = n_restarts;
        self
    }

    pub fn max_iter_predict(mut self, max_iter: usize) -> Self {
        self.max_iter_predict = max_iter;
        self
    }

    pub fn warm_start(mut self, warm_start: bool) -> Self {
        self.warm_start = warm_start;
        self
    }

    pub fn copy_x_train(mut self, copy_x_train: bool) -> Self {
        self.copy_x_train = copy_x_train;
        self
    }

    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }
}

impl Estimator for MultiClassGaussianProcessClassifier<Untrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Estimator for MultiClassGaussianProcessClassifier<McGpcTrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>>
    for MultiClassGaussianProcessClassifier<Untrained>
{
    type Fitted = MultiClassGaussianProcessClassifier<McGpcTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<i32>) -> SklResult<Self::Fitted> {
        let kernel = self
            .kernel
            .as_ref()
            .ok_or_else(|| SklearsError::InvalidInput("Kernel must be specified".to_string()))?
            .clone();

        let mut classes_set: HashSet<i32> = HashSet::new();
        for &label in y.iter() {
            classes_set.insert(label);
        }

        if classes_set.len() < 2 {
            return Err(SklearsError::InvalidInput(
                "Multi-class classification requires at least 2 classes".to_string(),
            ));
        }

        let mut classes: Vec<i32> = classes_set.into_iter().collect();
        classes.sort();
        let classes = Array1::from(classes);
        let n_classes = classes.len();

        if n_classes == 2 {
            let binary_gpc = GaussianProcessClassifier::new()
                .kernel(kernel)
                .optimizer(self.optimizer.clone())
                .n_restarts_optimizer(self.n_restarts_optimizer)
                .max_iter_predict(self.max_iter_predict)
                .warm_start(self.warm_start)
                .copy_x_train(self.copy_x_train)
                .random_state(self.random_state);

            let fitted_binary = binary_gpc.fit(X, y)?;

            let X_train = if self.copy_x_train {
                Some(X.to_owned())
            } else {
                None
            };

            return Ok(MultiClassGaussianProcessClassifier {
                state: McGpcTrained {
                    X_train,
                    y_train: y.to_owned(),
                    classes,
                    binary_classifiers: vec![fitted_binary],
                    n_classes,
                },
                kernel: None,
                optimizer: self.optimizer.clone(),
                n_restarts_optimizer: self.n_restarts_optimizer,
                max_iter_predict: self.max_iter_predict,
                warm_start: self.warm_start,
                copy_x_train: self.copy_x_train,
                random_state: self.random_state,
                config: self.config.clone(),
            });
        }

        let mut binary_classifiers = Vec::with_capacity(n_classes);

        for (class_idx, &current_class) in classes.iter().enumerate() {
            let y_binary: Array1<i32> = y.mapv(|label| if label == current_class { 1 } else { 0 });

            let binary_gpc = GaussianProcessClassifier::new()
                .kernel(kernel.clone())
                .optimizer(self.optimizer.clone())
                .n_restarts_optimizer(self.n_restarts_optimizer)
                .max_iter_predict(self.max_iter_predict)
                .warm_start(self.warm_start)
                .copy_x_train(self.copy_x_train)
                .random_state(self.random_state.map(|s| s + class_idx as u64));

            let fitted_binary = binary_gpc.fit(X, &y_binary.view())?;
            binary_classifiers.push(fitted_binary);
        }

        let X_train = if self.copy_x_train {
            Some(X.to_owned())
        } else {
            None
        };

        Ok(MultiClassGaussianProcessClassifier {
            state: McGpcTrained {
                X_train,
                y_train: y.to_owned(),
                classes,
                binary_classifiers,
                n_classes,
            },
            kernel: None,
            optimizer: self.optimizer,
            n_restarts_optimizer: self.n_restarts_optimizer,
            max_iter_predict: self.max_iter_predict,
            warm_start: self.warm_start,
            copy_x_train: self.copy_x_train,
            random_state: self.random_state,
            config: self.config.clone(),
        })
    }
}

impl Predict<ArrayView2<'_, f64>, Array1<i32>>
    for MultiClassGaussianProcessClassifier<McGpcTrained>
{
    fn predict(&self, X: &ArrayView2<f64>) -> SklResult<Array1<i32>> {
        let probabilities = self.predict_proba(X)?;
        let predictions: Vec<i32> = probabilities
            .axis_iter(Axis(0))
            .map(|row| {
                let max_idx = row
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap();
                self.state.classes[max_idx]
            })
            .collect();
        Ok(Array1::from(predictions))
    }
}

impl PredictProba<ArrayView2<'_, f64>, Array2<f64>>
    for MultiClassGaussianProcessClassifier<McGpcTrained>
{
    fn predict_proba(&self, X: &ArrayView2<f64>) -> SklResult<Array2<f64>> {
        let n_samples = X.nrows();
        let n_classes = self.state.n_classes;

        if n_classes == 2 {
            return self.state.binary_classifiers[0].predict_proba(X);
        }

        let mut all_probabilities = Array2::<f64>::zeros((n_samples, n_classes));

        for (class_idx, binary_classifier) in self.state.binary_classifiers.iter().enumerate() {
            let binary_proba = binary_classifier.predict_proba(X)?;
            for i in 0..n_samples {
                all_probabilities[[i, class_idx]] = binary_proba[[i, 1]];
            }
        }

        for i in 0..n_samples {
            let row_sum: f64 = all_probabilities.row(i).sum();
            if row_sum > 1e-12 {
                for j in 0..n_classes {
                    all_probabilities[[i, j]] /= row_sum;
                }
            } else {
                for j in 0..n_classes {
                    all_probabilities[[i, j]] = 1.0 / n_classes as f64;
                }
            }
        }

        Ok(all_probabilities)
    }
}

impl MultiClassGaussianProcessClassifier<McGpcTrained> {
    pub fn classes(&self) -> &Array1<i32> {
        &self.state.classes
    }

    pub fn n_classes(&self) -> usize {
        self.state.n_classes
    }

    pub fn binary_classifiers(&self) -> &[GaussianProcessClassifier<GpcTrained>] {
        &self.state.binary_classifiers
    }

    pub fn log_marginal_likelihood(&self, class_idx: usize) -> Option<f64> {
        self.state
            .binary_classifiers
            .get(class_idx)
            .map(|classifier| classifier.log_marginal_likelihood())
    }

    pub fn average_log_marginal_likelihood(&self) -> f64 {
        let sum: f64 = self
            .state
            .binary_classifiers
            .iter()
            .map(|classifier| classifier.log_marginal_likelihood())
            .sum();
        sum / self.state.binary_classifiers.len() as f64
    }
}

impl Default for MultiClassGaussianProcessClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

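/// Binary Gaussian process classifier trained with Expectation Propagation
/// (EP) rather than a Laplace approximation. Labels are encoded internally
/// as -1/+1 for the probit-style moment matching; EP maintains per-site
/// natural parameters (`tau`, `nu`) and refines them one site at a time
/// until the largest change falls below `tol`. `damping` in (0, 1] scales
/// each site update for numerical stability.
///
/// A minimal sketch with the same hypothetical kernel as above:
///
/// ```ignore
/// let ep = ExpectationPropagationGaussianProcessClassifier::new()
///     .kernel(Box::new(RbfKernel::new(1.0)))
///     .max_iter(50)
///     .damping(0.7);
/// let model = ep.fit(&x.view(), &y.view())?;
/// let proba = model.predict_proba(&x_test.view())?;
/// ```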
#[derive(Debug, Clone)]
pub struct ExpectationPropagationGaussianProcessClassifier<S = Untrained> {
    state: S,
    kernel: Option<Box<dyn Kernel>>,
    max_iter: usize,
    tol: f64,
    damping: f64,
    min_variance: f64,
    verbose: bool,
    random_state: Option<u64>,
    config: GpcConfig,
}

#[derive(Debug, Clone)]
pub struct EpGpcTrained {
    pub X_train: Option<Array2<f64>>,
    pub y_train: Array1<i32>,
    pub classes: Array1<i32>,
    pub mu: Array1<f64>,
    pub Sigma: Array2<f64>,
    pub tau: Array1<f64>,
    pub nu: Array1<f64>,
    pub kernel: Box<dyn Kernel>,
    pub log_marginal_likelihood_value: f64,
    pub n_iterations: usize,
}

impl ExpectationPropagationGaussianProcessClassifier<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernel: None,
            max_iter: 100,
            tol: 1e-4,
            damping: 0.5,
            min_variance: 1e-10,
            verbose: false,
            random_state: None,
            config: GpcConfig::default(),
        }
    }

    pub fn kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
        self.kernel = Some(kernel);
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    pub fn damping(mut self, damping: f64) -> Self {
        self.damping = damping.clamp(0.0, 1.0);
        self
    }

    pub fn min_variance(mut self, min_variance: f64) -> Self {
        self.min_variance = min_variance;
        self
    }

    pub fn verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }
}

impl Estimator for ExpectationPropagationGaussianProcessClassifier<Untrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Estimator for ExpectationPropagationGaussianProcessClassifier<EpGpcTrained> {
    type Config = GpcConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, i32>>
    for ExpectationPropagationGaussianProcessClassifier<Untrained>
{
    type Fitted = ExpectationPropagationGaussianProcessClassifier<EpGpcTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<i32>) -> SklResult<Self::Fitted> {
        if X.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let kernel = self
            .kernel
            .as_ref()
            .ok_or_else(|| SklearsError::InvalidInput("Kernel must be specified".to_string()))?
            .clone();

        let mut classes_set: HashSet<i32> = HashSet::new();
        for &label in y.iter() {
            classes_set.insert(label);
        }

        if classes_set.len() != 2 {
            return Err(SklearsError::InvalidInput(
                "Binary classification requires exactly 2 classes".to_string(),
            ));
        }

        let mut classes: Vec<i32> = classes_set.into_iter().collect();
        classes.sort();
        let classes = Array1::from(classes);

        let y_binary = y.mapv(|label| if label == classes[0] { -1.0 } else { 1.0 });

        let X_owned = X.to_owned();
        let K = kernel.compute_kernel_matrix(&X_owned, None)?;

        let (mu, Sigma, tau, nu, log_marginal_likelihood_value, n_iterations) =
            expectation_propagation(
                &K,
                &y_binary,
                self.max_iter,
                self.tol,
                self.damping,
                self.min_variance,
                self.verbose,
            )?;

        let X_train = Some(X.to_owned());

        Ok(ExpectationPropagationGaussianProcessClassifier {
            state: EpGpcTrained {
                X_train,
                y_train: y.to_owned(),
                classes,
                mu,
                Sigma,
                tau,
                nu,
                kernel,
                log_marginal_likelihood_value,
                n_iterations,
            },
            kernel: None,
            max_iter: self.max_iter,
            tol: self.tol,
            damping: self.damping,
            min_variance: self.min_variance,
            verbose: self.verbose,
            random_state: self.random_state,
            config: self.config.clone(),
        })
    }
}

impl Predict<ArrayView2<'_, f64>, Array1<i32>>
    for ExpectationPropagationGaussianProcessClassifier<EpGpcTrained>
{
    fn predict(&self, X: &ArrayView2<f64>) -> SklResult<Array1<i32>> {
        let probabilities = self.predict_proba(X)?;
        let predictions: Vec<i32> = probabilities
            .axis_iter(Axis(0))
            .map(|row| {
                let max_idx = row
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap();
                self.state.classes[max_idx]
            })
            .collect();
        Ok(Array1::from(predictions))
    }
}

impl PredictProba<ArrayView2<'_, f64>, Array2<f64>>
    for ExpectationPropagationGaussianProcessClassifier<EpGpcTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<f64>) -> SklResult<Array2<f64>> {
        let X_train = self.state.X_train.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("Training data not available".to_string())
        })?;

        let X_test_owned = X.to_owned();
        let K_star = self
            .state
            .kernel
            .compute_kernel_matrix(X_train, Some(&X_test_owned))?;

        let (f_star_mean, f_star_var) = ep_predict(&K_star, &self.state.mu, &self.state.Sigma)?;

        let mut probabilities = Array2::<f64>::zeros((X.nrows(), 2));
        for (i, (&mean, &var)) in f_star_mean.iter().zip(f_star_var.iter()).enumerate() {
            let std_dev = var.sqrt().max(1e-10);
            let z = mean / std_dev;
            let prob_positive = normal_cdf(z);
            probabilities[[i, 0]] = 1.0 - prob_positive;
            probabilities[[i, 1]] = prob_positive;
        }

        Ok(probabilities)
    }
}

impl ExpectationPropagationGaussianProcessClassifier<EpGpcTrained> {
    pub fn log_marginal_likelihood(&self) -> f64 {
        self.state.log_marginal_likelihood_value
    }

    pub fn classes(&self) -> &Array1<i32> {
        &self.state.classes
    }

    pub fn posterior_mean(&self) -> &Array1<f64> {
        &self.state.mu
    }

    pub fn posterior_covariance(&self) -> &Array2<f64> {
        &self.state.Sigma
    }

    pub fn n_iterations(&self) -> usize {
        self.state.n_iterations
    }
}

impl Default for ExpectationPropagationGaussianProcessClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

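/// Sequential EP for GP binary classification (cf. Rasmussen & Williams,
/// ch. 3.6). Site parameters start at zero, so the initial posterior equals
/// the prior `N(0, K)`. Each sweep visits every site i: form the cavity
/// distribution by removing site i from the current posterior, match moments
/// against the tilted distribution (likelihood_i x cavity) via
/// `marginal_moments`, then fold the damped site update back into
/// `(mu, Sigma)` with a rank-one update. Returns the posterior mean and
/// covariance, the site parameters, an EP-style estimate of the log marginal
/// likelihood, and the number of sweeps performed.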
#[allow(non_snake_case)]
fn expectation_propagation(
    K: &Array2<f64>,
    y: &Array1<f64>,
    max_iter: usize,
    tol: f64,
    damping: f64,
    min_variance: f64,
    verbose: bool,
) -> SklResult<(
    Array1<f64>,
    Array2<f64>,
    Array1<f64>,
    Array1<f64>,
    f64,
    usize,
)> {
    let n = K.nrows();

    let mut tau = Array1::<f64>::zeros(n);
    let mut nu = Array1::<f64>::zeros(n);
    let mut mu = Array1::<f64>::zeros(n);
    let mut Sigma = K.clone();
    let mut converged = false;
    let mut iteration = 0;

    for iter in 0..max_iter {
        iteration = iter + 1;
        let mut max_change: f64 = 0.0;

        for i in 0..n {
            let tau_cavity = 1.0 / Sigma[[i, i]] - tau[i];
            let mu_cavity = if tau_cavity > 1e-12 {
                mu[i] / (tau_cavity * Sigma[[i, i]]) - nu[i] / tau_cavity
            } else {
                0.0
            };

            let sigma_cavity = if tau_cavity > 1e-12 {
                1.0 / tau_cavity
            } else {
                1e6
            };
            let (z0, z1, z2) = marginal_moments(y[i], mu_cavity, sigma_cavity);

            if z0 > 1e-12 {
                let delta_tau =
                    z1 / sigma_cavity - z2 / (sigma_cavity * sigma_cavity) - tau_cavity;
                let delta_nu = z1 / sigma_cavity - mu_cavity * tau_cavity;

                let tau_new = tau[i] + damping * delta_tau;
                let nu_new = nu[i] + damping * delta_nu;

                let change = (tau_new - tau[i]).abs() + (nu_new - nu[i]).abs();
                max_change = max_change.max(change);

                tau[i] = tau_new.max(min_variance);
                nu[i] = nu_new;

                let tau_diff = tau[i] - tau_cavity;
                let nu_diff = nu[i] - mu_cavity * tau_cavity;

                if tau_diff.abs() > 1e-12 {
                    let si = Sigma.column(i).to_owned();
                    let denom = 1.0 + tau_diff * Sigma[[i, i]];

                    if denom.abs() > 1e-12 {
                        for j in 0..n {
                            for k in 0..n {
                                Sigma[[j, k]] -= tau_diff * si[j] * si[k] / denom;
                            }
                        }

                        mu = &mu + (nu_diff / denom) * &si;
                    }
                }
            }
        }

        if verbose && iter % 10 == 0 {
            println!("EP iteration {}: max change = {:.6}", iter, max_change);
        }

        if max_change < tol {
            converged = true;
            if verbose {
                println!("EP converged at iteration {}", iter);
            }
            break;
        }
    }

    if !converged && verbose {
        println!("EP did not converge after {} iterations", max_iter);
    }

    let log_marginal_likelihood = compute_ep_log_marginal_likelihood(K, &tau, &nu, &mu, &Sigma, y)?;

    Ok((mu, Sigma, tau, nu, log_marginal_likelihood, iteration))
}

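/// Zeroth, first, and second moments of the tilted distribution
/// `Phi(y * f) * N(f; mu, sigma2)`, computed through the standard Gaussian
/// ratio `phi(z) / Phi(z)`. Note this simplified form uses
/// `z = y * mu / sigma` rather than the `z = y * mu / sqrt(1 + sigma2)` of
/// the full probit-likelihood derivation.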
fn marginal_moments(y: f64, mu: f64, sigma2: f64) -> (f64, f64, f64) {
    let sigma = sigma2.sqrt();
    let z = y * mu / sigma;

    let pdf = (-0.5 * z * z).exp() / (2.0 * PI).sqrt();
    let cdf = normal_cdf(z);

    let z0 = cdf.max(1e-12);
    let ratio = if z0 > 1e-12 { pdf / z0 } else { 0.0 };

    let z1 = mu + y * sigma * ratio;
    let z2 = mu * mu + sigma2 * (1.0 - z * ratio - ratio * ratio);

    (z0, z1, z2.max(min_variance_global()))
}

fn normal_cdf(x: f64) -> f64 {
    0.5 * (1.0 + erf(x / 2.0_f64.sqrt()))
}

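/// Polynomial approximation of the error function from Abramowitz & Stegun,
/// formula 7.1.26; the maximum absolute error is about 1.5e-7, which is
/// ample for the probit probabilities computed above.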
fn erf(x: f64) -> f64 {
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;
    let p = 0.3275911;

    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) * (-x * x).exp();

    sign * y
}

fn min_variance_global() -> f64 {
    1e-10
}

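/// Predictive latent mean and variance under the EP posterior, as computed
/// here: `mean = K_*^T Sigma^{-1} mu` and
/// `var_i = 1.0 - [K_*^T Sigma^{-1} K_*]_{ii}`. The `1.0` assumes a unit
/// prior variance `k(x_*, x_*) = 1` (e.g. a normalised RBF kernel); kernels
/// with a different diagonal would need that term computed explicitly.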
#[allow(non_snake_case)]
fn ep_predict(
    K_star: &Array2<f64>,
    mu: &Array1<f64>,
    Sigma: &Array2<f64>,
) -> SklResult<(Array1<f64>, Array1<f64>)> {
    let Sigma_inv = utils::matrix_inverse(Sigma)?;
    let mean = K_star.t().dot(&Sigma_inv.dot(mu));

    let temp = K_star.t().dot(&Sigma_inv);
    let var_reduction = temp.dot(K_star);

    let variance = Array1::from_iter((0..K_star.ncols()).map(|i| 1.0 - var_reduction[[i, i]]));

    Ok((mean, variance))
}

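/// EP approximation of `log p(y | X)`, assembled from the determinant ratio
/// of posterior and prior covariances, the per-site Gaussian normalisers,
/// and quadratic terms in the posterior mean. This follows the usual EP
/// evidence decomposition in spirit but is a simplified variant, so the
/// value is best treated as a relative model-selection score rather than an
/// exact marginal likelihood.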
#[allow(non_snake_case)]
fn compute_ep_log_marginal_likelihood(
    K: &Array2<f64>,
    tau: &Array1<f64>,
    nu: &Array1<f64>,
    mu: &Array1<f64>,
    Sigma: &Array2<f64>,
    _y: &Array1<f64>,
) -> SklResult<f64> {
    let n = K.nrows();

    let L_Sigma = utils::robust_cholesky(Sigma)?;
    let log_det_Sigma = 2.0 * L_Sigma.diag().mapv(|x| x.ln()).sum();

    let L_K = utils::robust_cholesky(K)?;
    let log_det_K = 2.0 * L_K.diag().mapv(|x| x.ln()).sum();

    let quad_prior = 0.5 * mu.dot(&utils::triangular_solve(&L_K, mu)?);
    let quad_posterior = 0.5 * mu.dot(&utils::triangular_solve(&L_Sigma, mu)?);

    let mut site_contrib = 0.0;
    for i in 0..n {
        if tau[i] > 1e-12 {
            let mu_i = nu[i] / tau[i];
            site_contrib += 0.5 * (tau[i].ln() - (2.0 * PI).ln()) - 0.5 * tau[i] * mu_i * mu_i;
        }
    }

    let log_ml =
        -0.5 * log_det_Sigma + 0.5 * log_det_K + site_contrib + quad_prior - quad_posterior;

    Ok(log_ml)
}