use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Transform, Untrained},
    types::Float,
};
use std::collections::HashMap;

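/// Gaussian process (GP) imputer for missing values.
///
/// One GP regressor is fitted per feature: the remaining features of the
/// complete-case rows serve as inputs and the target feature as output.
/// Missing entries are then replaced with the GP posterior mean, and
/// `predict_with_uncertainty` additionally reports posterior standard
/// deviations and 95% confidence intervals.
///
/// Supported kernels: `"rbf"` (default), `"linear"`, `"matern32"`, and
/// `"matern52"`; unrecognized names fall back to the RBF kernel.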
#[derive(Debug, Clone)]
pub struct GaussianProcessImputer<S = Untrained> {
    state: S,
    kernel: String,
    length_scale: f64,
    length_scale_bounds: (f64, f64),
    nu: f64,
    alpha: f64,
    optimizer: String,
    n_restarts_optimizer: usize,
    missing_values: f64,
    normalize_y: bool,
    random_state: Option<u64>,
}

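/// Fitted state of [`GaussianProcessImputer`].
///
/// All per-feature quantities (training targets, Cholesky factors, dual
/// coefficients, inverse kernel matrices, log marginal likelihoods, and
/// optimized `(length_scale, alpha)` pairs) are keyed by the index of the
/// target feature.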
#[derive(Debug, Clone)]
#[allow(non_snake_case)]
pub struct GaussianProcessImputerTrained {
    X_train_: Array2<f64>,
    y_train_: HashMap<usize, Array1<f64>>,
    L_: HashMap<usize, Array2<f64>>,
    alpha_: HashMap<usize, Array1<f64>>,
    K_inv_: HashMap<usize, Array2<f64>>,
    log_marginal_likelihood_: HashMap<usize, f64>,
    n_features_in_: usize,
    optimized_kernel_params_: HashMap<usize, (f64, f64)>,
}

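/// Gaussian process posterior summary for a single cell, as returned by
/// `predict_with_uncertainty`.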
#[derive(Debug, Clone)]
pub struct GPPredictionResult {
    /// Posterior mean used as the imputed value.
    pub mean: f64,
    /// Posterior standard deviation of the prediction.
    pub std: f64,
    /// 95% confidence interval, `(mean - 1.96 * std, mean + 1.96 * std)`.
    pub confidence_interval_95: (f64, f64),
}

impl GaussianProcessImputer<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernel: "rbf".to_string(),
            length_scale: 1.0,
            length_scale_bounds: (1e-5, 1e5),
            nu: 1.5,
            alpha: 1e-10,
            optimizer: "fmin_l_bfgs_b".to_string(),
            n_restarts_optimizer: 0,
            missing_values: f64::NAN,
            normalize_y: false,
            random_state: None,
        }
    }

    pub fn kernel(mut self, kernel: String) -> Self {
        self.kernel = kernel;
        self
    }

    pub fn length_scale(mut self, length_scale: f64) -> Self {
        self.length_scale = length_scale;
        self
    }

    pub fn length_scale_bounds(mut self, bounds: (f64, f64)) -> Self {
        self.length_scale_bounds = bounds;
        self
    }

    pub fn nu(mut self, nu: f64) -> Self {
        self.nu = nu;
        self
    }

    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    pub fn optimizer(mut self, optimizer: String) -> Self {
        self.optimizer = optimizer;
        self
    }

    pub fn n_restarts_optimizer(mut self, n_restarts: usize) -> Self {
        self.n_restarts_optimizer = n_restarts;
        self
    }

    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }

    pub fn normalize_y(mut self, normalize_y: bool) -> Self {
        self.normalize_y = normalize_y;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

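    /// Evaluates the covariance k(x1, x2) for the configured kernel.
    ///
    /// `"linear"` uses the dot product, `"rbf"` uses
    /// exp(-0.5 * ||x1 - x2||^2 / l^2), and `"matern32"` / `"matern52"` use
    /// the Matérn 3/2 and 5/2 forms with r = ||x1 - x2|| / l. Any other
    /// kernel name falls back to the RBF expression.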
    fn kernel_function(
        &self,
        x1: &ArrayView1<f64>,
        x2: &ArrayView1<f64>,
        length_scale: f64,
    ) -> f64 {
        match self.kernel.as_str() {
            "linear" => x1.dot(x2),
            "rbf" => {
                let diff = (x1 - x2).mapv(|x| x * x).sum();
                (-0.5 * diff / (length_scale * length_scale)).exp()
            }
            "matern32" => {
                let r = ((x1 - x2).mapv(|x| x * x).sum().sqrt()) / length_scale;
                (1.0 + (3.0_f64).sqrt() * r) * (-(3.0_f64).sqrt() * r).exp()
            }
            "matern52" => {
                let r = ((x1 - x2).mapv(|x| x * x).sum().sqrt()) / length_scale;
                (1.0 + (5.0_f64).sqrt() * r + 5.0 * r * r / 3.0) * (-(5.0_f64).sqrt() * r).exp()
            }
            _ => {
                let diff = (x1 - x2).mapv(|x| x * x).sum();
                (-0.5 * diff / (length_scale * length_scale)).exp()
            }
        }
    }

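    /// Builds the training covariance matrix K for the configured kernel,
    /// adding `alpha` to the diagonal as a noise/jitter term so the Cholesky
    /// factorization stays numerically stable. The RBF path exploits symmetry
    /// by filling only the upper triangle and mirroring it.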
    fn compute_kernel_matrix(&self, X: &Array2<f64>, length_scale: f64, alpha: f64) -> Array2<f64> {
        let n_samples = X.nrows();

        match self.kernel.as_str() {
            "rbf" | "squared_exponential" => {
                let mut K = Array2::zeros((n_samples, n_samples));

                for i in 0..n_samples {
                    let x1 = X.row(i);
                    for j in i..n_samples {
                        let x2 = X.row(j);
                        let dist_sq = x1
                            .iter()
                            .zip(x2.iter())
                            .map(|(a, b)| (a - b).powi(2))
                            .sum::<f64>();
                        let kernel_value = (-0.5 * dist_sq / (length_scale * length_scale)).exp();
                        K[[i, j]] = kernel_value;
                        if i != j {
                            K[[j, i]] = kernel_value;
                        }
                    }
                    K[[i, i]] += alpha;
                }
                K
            }
            _ => {
                let mut K = Array2::zeros((n_samples, n_samples));
                for i in 0..n_samples {
                    for j in 0..n_samples {
                        K[[i, j]] = self.kernel_function(&X.row(i), &X.row(j), length_scale);
                        if i == j {
                            K[[i, j]] += alpha;
                        }
                    }
                }
                K
            }
        }
    }

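    /// Computes the log marginal likelihood
    /// log p(y | X) = -0.5 * yᵀ K⁻¹ y - Σᵢ log Lᵢᵢ - (n/2) log(2π),
    /// where K = L Lᵀ is the Cholesky factorization of the kernel matrix.
    /// Returns negative infinity when K is not positive definite.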
    #[allow(non_snake_case)]
    fn log_marginal_likelihood(
        &self,
        X: &Array2<f64>,
        y: &Array1<f64>,
        length_scale: f64,
        alpha: f64,
    ) -> f64 {
        let K = self.compute_kernel_matrix(X, length_scale, alpha);

        if let Ok(L) = self.cholesky_decomposition(&K) {
            let alpha_vec = self.solve_triangular_lower(&L, y);

            // With K = L Lᵀ and z = L⁻¹ y, the data-fit term -0.5 * yᵀ K⁻¹ y equals -0.5 * zᵀ z.
            let data_fit = -0.5 * alpha_vec.dot(&alpha_vec);
            let complexity_penalty = -L.diag().mapv(|x| x.ln()).sum();
            let normalizing_constant = -0.5 * y.len() as f64 * (2.0 * std::f64::consts::PI).ln();

            data_fit + complexity_penalty + normalizing_constant
        } else {
            f64::NEG_INFINITY
        }
    }

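    /// Selects `(length_scale, alpha)` by a small grid search over fixed
    /// candidate values, keeping the pair with the highest log marginal
    /// likelihood; the configured `optimizer` string is not consulted here.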
    fn optimize_hyperparameters(&self, X: &Array2<f64>, y: &Array1<f64>) -> (f64, f64) {
        let mut best_length_scale = self.length_scale;
        let mut best_alpha = self.alpha;
        let mut best_likelihood = f64::NEG_INFINITY;

        let length_scales = [0.1, 0.5, 1.0, 2.0, 5.0];
        let alphas = [1e-10, 1e-8, 1e-6, 1e-4, 1e-2];

        for &ls in &length_scales {
            for &alpha in &alphas {
                let likelihood = self.log_marginal_likelihood(X, y, ls, alpha);
                if likelihood > best_likelihood {
                    best_likelihood = likelihood;
                    best_length_scale = ls;
                    best_alpha = alpha;
                }
            }
        }

        (best_length_scale, best_alpha)
    }

    fn solve_triangular_lower(&self, L: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
        let n = L.nrows();
        let mut y = Array1::zeros(n);

        for i in 0..n {
            let mut sum = 0.0;
            for j in 0..i {
                sum += L[[i, j]] * y[j];
            }
            y[i] = (b[i] - sum) / L[[i, i]];
        }

        y
    }

    fn solve_triangular_upper(&self, U: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
        let n = U.nrows();
        let mut x = Array1::zeros(n);

        for i in (0..n).rev() {
            let mut sum = 0.0;
            for j in (i + 1)..n {
                sum += U[[i, j]] * x[j];
            }
            x[i] = (b[i] - sum) / U[[i, i]];
        }

        x
    }

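    /// Cholesky factorization A = L Lᵀ via the standard Cholesky–Banachiewicz
    /// recurrence. Returns an error if a non-positive pivot is encountered,
    /// i.e. if A is not positive definite.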
    fn cholesky_decomposition(&self, A: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n = A.nrows();
        let mut L = Array2::zeros((n, n));

        for i in 0..n {
            for j in 0..=i {
                if i == j {
                    let mut sum = 0.0;
                    for k in 0..j {
                        sum += L[[j, k]] * L[[j, k]];
                    }
                    let val = A[[j, j]] - sum;
                    if val <= 0.0 {
                        return Err(SklearsError::InvalidInput(
                            "Matrix is not positive definite".to_string(),
                        ));
                    }
                    L[[j, j]] = val.sqrt();
                } else {
                    let mut sum = 0.0;
                    for k in 0..j {
                        sum += L[[i, k]] * L[[j, k]];
                    }
                    L[[i, j]] = (A[[i, j]] - sum) / L[[j, j]];
                }
            }
        }

        Ok(L)
    }
}

impl Default for GaussianProcessImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for GaussianProcessImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for GaussianProcessImputer<Untrained> {
    type Fitted = GaussianProcessImputer<GaussianProcessImputerTrained>;

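    /// Fits one GP per feature on the complete-case rows.
    ///
    /// For each target feature, the remaining features form the inputs; the
    /// kernel hyperparameters are chosen by grid search, the kernel matrix is
    /// Cholesky-factorized, and the dual coefficients K⁻¹y together with an
    /// explicit K⁻¹ (needed for predictive variances) are precomputed and
    /// stored in the trained state.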
    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        let mut complete_rows = Vec::new();
        for i in 0..n_samples {
            let mut is_complete = true;
            for j in 0..n_features {
                if self.is_missing(X[[i, j]]) {
                    is_complete = false;
                    break;
                }
            }
            if is_complete {
                complete_rows.push(i);
            }
        }

        if complete_rows.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No complete cases found for training".to_string(),
            ));
        }

        let mut X_train = Array2::zeros((complete_rows.len(), n_features));
        for (new_i, &orig_i) in complete_rows.iter().enumerate() {
            for j in 0..n_features {
                X_train[[new_i, j]] = X[[orig_i, j]];
            }
        }

        let mut y_train = HashMap::new();
        let mut L_chol = HashMap::new();
        let mut alpha_coeffs = HashMap::new();
        let mut K_inv_matrices = HashMap::new();
        let mut log_likelihoods = HashMap::new();
        let mut optimized_params = HashMap::new();

        for target_feature in 0..n_features {
            let mut X_feat = Array2::zeros((complete_rows.len(), n_features - 1));
            let mut col_idx = 0;
            for j in 0..n_features {
                if j != target_feature {
                    for i in 0..complete_rows.len() {
                        X_feat[[i, col_idx]] = X_train[[i, j]];
                    }
                    col_idx += 1;
                }
            }

            let y_target = X_train.column(target_feature).to_owned();

            let (opt_length_scale, opt_alpha) = self.optimize_hyperparameters(&X_feat, &y_target);

            let K = self.compute_kernel_matrix(&X_feat, opt_length_scale, opt_alpha);

            let L = self.cholesky_decomposition(&K)?;

            let alpha_vec = self.solve_triangular_lower(&L, &y_target);
            let alpha_final = self.solve_triangular_upper(&L.t().to_owned(), &alpha_vec);

            let I = Array2::eye(K.nrows());
            let mut K_inv = Array2::zeros((K.nrows(), K.ncols()));
            for i in 0..I.ncols() {
                let col = I.column(i);
                let y_temp = self.solve_triangular_lower(&L, &col.to_owned());
                let x_temp = self.solve_triangular_upper(&L.t().to_owned(), &y_temp);
                for j in 0..K.nrows() {
                    K_inv[[j, i]] = x_temp[j];
                }
            }

            let log_likelihood =
                self.log_marginal_likelihood(&X_feat, &y_target, opt_length_scale, opt_alpha);

            y_train.insert(target_feature, y_target);
            L_chol.insert(target_feature, L);
            alpha_coeffs.insert(target_feature, alpha_final);
            K_inv_matrices.insert(target_feature, K_inv);
            log_likelihoods.insert(target_feature, log_likelihood);
            optimized_params.insert(target_feature, (opt_length_scale, opt_alpha));
        }

        Ok(GaussianProcessImputer {
            state: GaussianProcessImputerTrained {
                X_train_: X_train,
                y_train_: y_train,
                L_: L_chol,
                alpha_: alpha_coeffs,
                K_inv_: K_inv_matrices,
                log_marginal_likelihood_: log_likelihoods,
                n_features_in_: n_features,
                optimized_kernel_params_: optimized_params,
            },
            kernel: self.kernel,
            length_scale: self.length_scale,
            length_scale_bounds: self.length_scale_bounds,
            nu: self.nu,
            alpha: self.alpha,
            optimizer: self.optimizer,
            n_restarts_optimizer: self.n_restarts_optimizer,
            missing_values: self.missing_values,
            normalize_y: self.normalize_y,
            random_state: self.random_state,
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for GaussianProcessImputer<GaussianProcessImputerTrained>
{
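    /// Imputes missing entries on a copy of the input, replacing each missing
    /// cell with the GP posterior mean for its feature. Values imputed earlier
    /// in a row are visible when predicting later cells of the same row.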
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();

        for i in 0..n_samples {
            for j in 0..n_features {
                if self.is_missing(X_imputed[[i, j]]) {
                    let prediction = self.predict_gp(&X_imputed, i, j)?;
                    X_imputed[[i, j]] = prediction.mean;
                }
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl GaussianProcessImputer<GaussianProcessImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    fn kernel_function(
        &self,
        x1: &ArrayView1<f64>,
        x2: &ArrayView1<f64>,
        length_scale: f64,
    ) -> f64 {
        match self.kernel.as_str() {
            "linear" => x1.dot(x2),
            "rbf" => {
                let diff = (x1 - x2).mapv(|x| x * x).sum();
                (-0.5 * diff / (length_scale * length_scale)).exp()
            }
            "matern32" => {
                let r = ((x1 - x2).mapv(|x| x * x).sum().sqrt()) / length_scale;
                (1.0 + (3.0_f64).sqrt() * r) * (-(3.0_f64).sqrt() * r).exp()
            }
            "matern52" => {
                let r = ((x1 - x2).mapv(|x| x * x).sum().sqrt()) / length_scale;
                (1.0 + (5.0_f64).sqrt() * r + 5.0 * r * r / 3.0) * (-(5.0_f64).sqrt() * r).exp()
            }
            _ => {
                let diff = (x1 - x2).mapv(|x| x * x).sum();
                (-0.5 * diff / (length_scale * length_scale)).exp()
            }
        }
    }

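    /// Returns, for every cell, a `GPPredictionResult` with posterior mean,
    /// standard deviation, and 95% confidence interval. Observed cells are
    /// passed through with zero standard deviation and a degenerate interval.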
    #[allow(non_snake_case)]
    pub fn predict_with_uncertainty(
        &self,
        X: &ArrayView2<'_, Float>,
    ) -> SklResult<Vec<Vec<GPPredictionResult>>> {
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut predictions = Vec::new();

        for i in 0..n_samples {
            let mut sample_predictions = Vec::new();
            for j in 0..n_features {
                if self.is_missing(X[[i, j]]) {
                    let prediction = self.predict_gp(&X, i, j)?;
                    sample_predictions.push(prediction);
                } else {
                    sample_predictions.push(GPPredictionResult {
                        mean: X[[i, j]],
                        std: 0.0,
                        confidence_interval_95: (X[[i, j]], X[[i, j]]),
                    });
                }
            }
            predictions.push(sample_predictions);
        }

        Ok(predictions)
    }

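    /// Predicts a single missing cell from the GP for `target_feature`.
    ///
    /// Uses the standard GP posterior: mean = k*ᵀ K⁻¹ y and
    /// variance = k(x*, x*) + alpha - k*ᵀ K⁻¹ k*. The mean is clamped to a
    /// band around the observed target range, and both mean and standard
    /// deviation fall back to the empirical mean/variance of the training
    /// targets if the GP values are not finite. Missing predictor values are
    /// replaced by the training column means.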
    #[allow(non_snake_case)]
    fn predict_gp(
        &self,
        X: &Array2<f64>,
        sample_idx: usize,
        target_feature: usize,
    ) -> SklResult<GPPredictionResult> {
        let (length_scale, alpha) = self
            .state
            .optimized_kernel_params_
            .get(&target_feature)
            .ok_or_else(|| {
                SklearsError::InvalidInput("Missing optimized parameters".to_string())
            })?;

        let mut x_feat = Array1::zeros(self.state.n_features_in_ - 1);
        let mut col_idx = 0;
        for j in 0..self.state.n_features_in_ {
            if j != target_feature {
                x_feat[col_idx] = if self.is_missing(X[[sample_idx, j]]) {
                    self.state.X_train_.column(j).mean().unwrap_or(0.0)
                } else {
                    X[[sample_idx, j]]
                };
                col_idx += 1;
            }
        }

        let X_train_feat = self.get_training_features(target_feature);
        let alpha_vec =
            self.state.alpha_.get(&target_feature).ok_or_else(|| {
                SklearsError::InvalidInput("Missing alpha coefficients".to_string())
            })?;
        let y_train =
            self.state.y_train_.get(&target_feature).ok_or_else(|| {
                SklearsError::InvalidInput("Missing training targets".to_string())
            })?;
        let K_inv = self
            .state
            .K_inv_
            .get(&target_feature)
            .ok_or_else(|| SklearsError::InvalidInput("Missing K_inv matrix".to_string()))?;

        let mut k_star = Array1::zeros(X_train_feat.nrows());
        for i in 0..X_train_feat.nrows() {
            k_star[i] = self.kernel_function(&x_feat.view(), &X_train_feat.row(i), *length_scale);
        }

        let mut mean = k_star.dot(alpha_vec);

        let mut y_min = f64::INFINITY;
        let mut y_max = f64::NEG_INFINITY;
        let mut y_sum = 0.0;
        for &val in y_train.iter() {
            if val.is_finite() {
                y_min = y_min.min(val);
                y_max = y_max.max(val);
                y_sum += val;
            }
        }
        let sample_count = y_train.len() as f64;
        let fallback_mean = if sample_count > 0.0 {
            y_sum / sample_count
        } else {
            0.0
        };

        if !mean.is_finite() {
            mean = fallback_mean;
        } else if y_min.is_finite() && y_max.is_finite() {
            let range = (y_max - y_min).abs().max(1.0);
            let lower_bound = y_min - 0.5 * range;
            let upper_bound = y_max + 0.5 * range;
            if mean < lower_bound {
                mean = lower_bound;
            } else if mean > upper_bound {
                mean = upper_bound;
            }
        }

        let k_star_star =
            self.kernel_function(&x_feat.view(), &x_feat.view(), *length_scale) + alpha;
        let variance = k_star_star - k_star.dot(&K_inv.dot(&k_star));
        let mut std = variance.max(0.0).sqrt();
        if !std.is_finite() {
            let fallback_var = if sample_count > 1.0 {
                y_train
                    .iter()
                    .map(|&v| {
                        let diff = v - fallback_mean;
                        diff * diff
                    })
                    .sum::<f64>()
                    / (sample_count - 1.0)
            } else {
                0.0
            };
            std = fallback_var.max(0.0).sqrt();
        }

        let ci_width = 1.96 * std;
        let confidence_interval_95 = (mean - ci_width, mean + ci_width);

        Ok(GPPredictionResult {
            mean,
            std,
            confidence_interval_95,
        })
    }

    fn get_training_features(&self, target_feature: usize) -> Array2<f64> {
        let n_train = self.state.X_train_.nrows();
        let mut X_feat = Array2::zeros((n_train, self.state.n_features_in_ - 1));

        let mut col_idx = 0;
        for j in 0..self.state.n_features_in_ {
            if j != target_feature {
                for i in 0..n_train {
                    X_feat[[i, col_idx]] = self.state.X_train_[[i, j]];
                }
                col_idx += 1;
            }
        }

        X_feat
    }

    pub fn log_marginal_likelihood(&self) -> &HashMap<usize, f64> {
        &self.state.log_marginal_likelihood_
    }

    pub fn optimized_kernel_params(&self) -> &HashMap<usize, (f64, f64)> {
        &self.state.optimized_kernel_params_
    }
}
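
// Minimal usage sketch: fits the imputer on a small matrix with one NaN entry
// and checks that the missing cell is filled with a finite GP posterior mean.
// Assumes `Float` resolves to `f64`, so plain `f64` arrays satisfy the
// `Fit`/`Transform` signatures above; the matrix values are illustrative only.
#[cfg(test)]
mod gp_imputer_usage_sketch {
    use super::*;

    #[test]
    fn imputes_missing_entry() {
        // Three complete rows for training, one row with a missing value.
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, //
                2.0, 4.0, 6.0, //
                3.0, 6.0, 9.0, //
                4.0, f64::NAN, 12.0,
            ],
        )
        .expect("shape matches the data length");

        let imputer = GaussianProcessImputer::new()
            .kernel("rbf".to_string())
            .length_scale(1.0);

        // Fit one GP per feature on the complete rows, then impute.
        let fitted = imputer.fit(&x.view(), &()).expect("fit should succeed");
        let imputed = fitted.transform(&x.view()).expect("transform should succeed");

        // The previously missing cell is now a finite posterior mean.
        assert!(imputed[[3, 1]].is_finite());
    }
}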