use crate::kernels::Kernel;
use crate::utils::{robust_cholesky, triangular_solve};
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::error::{Result as SklResult, SklearsError};
use std::f64::consts::PI;

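/// Adam-based optimizer that maximizes the log marginal likelihood of a
/// Gaussian process with respect to the kernel hyperparameters and the noise
/// level, working in log-parameter space so all parameters stay positive.
///
/// Illustrative usage sketch (not a doctest; assumes an `RBF` kernel as used in
/// the tests at the bottom of this module and hypothetical arrays `x` / `y`):
///
/// ```ignore
/// let optimizer = MarginalLikelihoodOptimizer::new()
///     .max_iter(50)
///     .learning_rate(0.05)
///     .verbose(true);
/// let mut kernel: Box<dyn Kernel> = Box::new(RBF::new(1.0));
/// let result = optimizer.optimize(&x.view(), &y.view(), &mut kernel, 0.1)?;
/// println!("best LML = {}", result.optimal_log_marginal_likelihood);
/// ```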
#[derive(Debug, Clone)]
pub struct MarginalLikelihoodOptimizer {
    max_iter: usize,
    tol: f64,
    learning_rate: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
    line_search: bool,
    verbose: bool,
}

#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Optimized hyperparameters: kernel parameters followed by the noise level.
    pub optimal_params: Array1<f64>,
    /// Log marginal likelihood at the optimum.
    pub optimal_log_marginal_likelihood: f64,
    /// Number of iterations performed.
    pub n_iterations: usize,
    /// Whether the gradient-norm tolerance was reached.
    pub converged: bool,
    /// Log marginal likelihood at each iteration.
    pub lml_history: Vec<f64>,
    /// Gradient norm at each iteration.
    pub gradient_norm_history: Vec<f64>,
}

impl MarginalLikelihoodOptimizer {
    /// Create an optimizer with default settings (100 iterations, tolerance
    /// 1e-6, learning rate 0.01, standard Adam moment parameters).
    pub fn new() -> Self {
        Self {
            max_iter: 100,
            tol: 1e-6,
            learning_rate: 0.01,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            line_search: false,
            verbose: false,
        }
    }

    /// Set the maximum number of iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance on the gradient norm.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the Adam learning rate.
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the Adam moment-decay and numerical-stability parameters.
    pub fn adam_params(mut self, beta1: f64, beta2: f64, epsilon: f64) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self.epsilon = epsilon;
        self
    }

    /// Enable or disable backtracking line search on each update.
    pub fn line_search(mut self, line_search: bool) -> Self {
        self.line_search = line_search;
        self
    }

    /// Enable or disable progress logging.
    pub fn verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Maximize the log marginal likelihood over the kernel hyperparameters and
    /// the noise level using Adam updates in log-parameter space.
    #[allow(non_snake_case)]
    pub fn optimize(
        &self,
        X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
        kernel: &mut Box<dyn Kernel>,
        sigma_n: f64,
    ) -> SklResult<OptimizationResult> {
        // Optimize in log space so the parameters stay positive: kernel
        // parameters first, noise level last.
        let n_params = kernel.get_params().len() + 1;
        let mut params = Array1::<f64>::zeros(n_params);

        let kernel_params = kernel.get_params();
        for (i, &param) in kernel_params.iter().enumerate() {
            params[i] = param.ln();
        }
        params[n_params - 1] = sigma_n.ln();

        // Adam first- and second-moment accumulators.
        let mut m = Array1::<f64>::zeros(n_params);
        let mut v = Array1::<f64>::zeros(n_params);

        let mut lml_history = Vec::new();
        let mut gradient_norm_history = Vec::new();
        let mut converged = false;

        for iter in 0..self.max_iter {
            // Map back from log space to the positive parameter domain.
            let exp_params = params.mapv(|x| x.exp());
            let kernel_params = exp_params.slice(s![..n_params - 1]).to_owned();
            let sigma_n_current = exp_params[n_params - 1];

            kernel.set_params(kernel_params.as_slice().unwrap())?;

            let (lml, grad) =
                self.compute_log_marginal_likelihood_and_gradients(X, y, kernel, sigma_n_current)?;

            lml_history.push(lml);
            let grad_norm = grad.dot(&grad).sqrt();
            gradient_norm_history.push(grad_norm);

            if self.verbose && iter % 10 == 0 {
                println!(
                    "Iteration {}: LML = {:.6}, ||grad|| = {:.6}",
                    iter, lml, grad_norm
                );
            }

            if grad_norm < self.tol {
                converged = true;
                if self.verbose {
                    println!("Converged at iteration {}", iter);
                }
                break;
            }

            // Adam moment updates with bias correction:
            //   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
            //   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
            //   update = m_hat / (sqrt(v_hat) + epsilon)
            let t = (iter + 1) as f64;
            m = self.beta1 * &m + (1.0 - self.beta1) * &grad;
            v = self.beta2 * &v + (1.0 - self.beta2) * grad.mapv(|x| x * x);

            let m_hat = &m / (1.0 - self.beta1.powf(t));
            let v_hat = &v / (1.0 - self.beta2.powf(t));

            let update = &m_hat / (v_hat.mapv(|x| x.sqrt()) + self.epsilon);

            if self.line_search {
                // Backtracking line search: halve the step until the LML improves
                // (at most five halvings); keep the old parameters otherwise.
                let mut step_size = self.learning_rate;
                for _ in 0..5 {
                    let params_new = &params + step_size * &update;
                    let exp_params_new = params_new.mapv(|x| x.exp());
                    let kernel_params_new = exp_params_new.slice(s![..n_params - 1]).to_owned();
                    let sigma_n_new = exp_params_new[n_params - 1];

                    kernel.set_params(kernel_params_new.as_slice().unwrap())?;
                    if let Ok((lml_new, _)) = self.compute_log_marginal_likelihood_and_gradients(
                        X,
                        y,
                        kernel,
                        sigma_n_new,
                    ) {
                        if lml_new > lml {
                            params = params_new;
                            break;
                        }
                    }
                    step_size *= 0.5;
                }
            } else {
                // Plain Adam ascent step on the log marginal likelihood.
                params = &params + self.learning_rate * &update;
            }
        }

        let exp_params = params.mapv(|x| x.exp());
        let kernel_params = exp_params.slice(s![..n_params - 1]).to_owned();
        kernel.set_params(kernel_params.as_slice().unwrap())?;

        let final_lml = lml_history.last().copied().unwrap_or(f64::NEG_INFINITY);

        Ok(OptimizationResult {
            optimal_params: exp_params,
            optimal_log_marginal_likelihood: final_lml,
            n_iterations: lml_history.len(),
            converged,
            lml_history,
            gradient_norm_history,
        })
    }

    /// Compute the log marginal likelihood and its gradient with respect to the
    /// log-transformed kernel parameters and noise level.
    #[allow(non_snake_case)]
    fn compute_log_marginal_likelihood_and_gradients(
        &self,
        X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
        kernel: &Box<dyn Kernel>,
        sigma_n: f64,
    ) -> SklResult<(f64, Array1<f64>)> {
        let n = X.nrows();

        // K = K_f(X, X) + sigma_n^2 * I
        let X_owned = X.to_owned();
        let mut K = kernel.compute_kernel_matrix(&X_owned, None)?;

        for i in 0..n {
            K[[i, i]] += sigma_n * sigma_n;
        }

        let L = robust_cholesky(&K)?;

        // Solve K * alpha = y via two triangular solves with the Cholesky factor
        // (this assumes `triangular_solve` solves a single triangular system, as
        // it is used in `compute_lml_gradient` below).
        let z = triangular_solve(&L, &y.to_owned())?;
        let alpha = triangular_solve(&L.t().to_owned(), &z)?;

        // log p(y | X) = -1/2 y^T K^{-1} y - 1/2 log|K| - n/2 log(2*pi)
        let log_det_K = 2.0 * L.diag().mapv(|x| x.ln()).sum();
        let quadratic_term = alpha.dot(y);
        let lml = -0.5 * quadratic_term - 0.5 * log_det_K - 0.5 * n as f64 * (2.0 * PI).ln();

        // Gradients: kernel parameters first, noise level last. The extra factor
        // of the (positive) parameter is the chain rule for the log transform.
        let kernel_params = kernel.get_params();
        let n_params = kernel_params.len() + 1;
        let mut gradients = Array1::<f64>::zeros(n_params);

        for (i, _) in kernel_params.iter().enumerate() {
            let grad_K = self.compute_kernel_gradient(X, kernel, i)?;
            let grad_lml = self.compute_lml_gradient(&grad_K, &alpha, &L)?;
            gradients[i] = grad_lml * kernel_params[i];
        }

        // dK/d(sigma_n) = 2 * sigma_n * I
        let mut grad_K_noise = Array2::<f64>::zeros((n, n));
        for i in 0..n {
            grad_K_noise[[i, i]] = 2.0 * sigma_n;
        }
        let grad_lml_noise = self.compute_lml_gradient(&grad_K_noise, &alpha, &L)?;
        gradients[n_params - 1] = grad_lml_noise * sigma_n;

        Ok((lml, gradients))
    }

    /// Approximate dK/d(theta_i) with a forward finite difference on the raw
    /// (non-log) kernel parameter.
    #[allow(non_snake_case)]
    fn compute_kernel_gradient(
        &self,
        X: &ArrayView2<f64>,
        kernel: &Box<dyn Kernel>,
        param_idx: usize,
    ) -> SklResult<Array2<f64>> {
        let params = kernel.get_params();
        let h = 1e-8;

        let mut params_plus = params.clone();
        params_plus[param_idx] += h;

        let mut kernel_plus = kernel.clone_box();
        kernel_plus.set_params(&params_plus)?;

        let X_owned = X.to_owned();
        let K_plus = kernel_plus.compute_kernel_matrix(&X_owned, None)?;
        let K = kernel.compute_kernel_matrix(&X_owned, None)?;

        let grad_K = (K_plus - K) / h;
        Ok(grad_K)
    }

    /// Gradient of the log marginal likelihood with respect to one parameter:
    /// d lml / d theta = 1/2 * tr((alpha * alpha^T - K^{-1}) * dK/d theta).
    #[allow(non_snake_case)]
    fn compute_lml_gradient(
        &self,
        grad_K: &Array2<f64>,
        alpha: &Array1<f64>,
        L: &Array2<f64>,
    ) -> SklResult<f64> {
        let n = L.nrows();
        let mut K_inv = Array2::<f64>::zeros((n, n));

        // Build K^{-1} column by column: forward-solve L * z = e_i, ...
        for i in 0..n {
            let mut e = Array1::<f64>::zeros(n);
            e[i] = 1.0;
            let col = triangular_solve(L, &e)?;
            K_inv.column_mut(i).assign(&col);
        }

        // ... then back-substitute with L^T to obtain K^{-1} e_i.
        for i in 0..n {
            let col = K_inv.column(i).to_owned();
            let inv_col = triangular_solve(&L.t().to_owned(), &col)?;
            K_inv.column_mut(i).assign(&inv_col);
        }

        // For symmetric matrices, tr(A * B) equals the sum of the elementwise product.
        let alpha_outer = Array2::from_shape_fn((n, n), |(i, j)| alpha[i] * alpha[j]);
        let diff = alpha_outer - K_inv;
        let grad = 0.5 * (&diff * grad_K).sum();

        Ok(grad)
    }
}

impl Default for MarginalLikelihoodOptimizer {
    fn default() -> Self {
        Self::new()
    }
}

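/// Convenience wrapper around [`MarginalLikelihoodOptimizer`] that optimizes a
/// kernel's hyperparameters and returns the full [`OptimizationResult`].
///
/// Illustrative sketch (not a doctest; assumes an `RBF` kernel as in the tests
/// below and hypothetical training arrays `x_train` / `y_train`):
///
/// ```ignore
/// let mut kernel: Box<dyn Kernel> = Box::new(RBF::new(1.0));
/// let result = optimize_hyperparameters(
///     &x_train.view(),
///     &y_train.view(),
///     &mut kernel,
///     0.1,   // initial noise level
///     100,   // max_iter
///     1e-6,  // tol
///     false, // verbose
/// )?;
/// assert!(result.optimal_log_marginal_likelihood.is_finite());
/// ```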
#[allow(non_snake_case)]
pub fn optimize_hyperparameters(
    X: &ArrayView2<f64>,
    y: &ArrayView1<f64>,
    kernel: &mut Box<dyn Kernel>,
    sigma_n: f64,
    max_iter: usize,
    tol: f64,
    verbose: bool,
) -> SklResult<OptimizationResult> {
    let optimizer = MarginalLikelihoodOptimizer::new()
        .max_iter(max_iter)
        .tol(tol)
        .verbose(verbose);

    optimizer.optimize(X, y, kernel, sigma_n)
}

/// Compute the log marginal likelihood of a GP with the given kernel and noise
/// level:  log p(y | X) = -1/2 y^T K^{-1} y - 1/2 log|K| - n/2 log(2*pi).
#[allow(non_snake_case)]
pub fn log_marginal_likelihood(
    X: &ArrayView2<f64>,
    y: &ArrayView1<f64>,
    kernel: &Box<dyn Kernel>,
    sigma_n: f64,
) -> SklResult<f64> {
    let n = X.nrows();

    let X_owned = X.to_owned();
    let mut K = kernel.compute_kernel_matrix(&X_owned, None)?;

    for i in 0..n {
        K[[i, i]] += sigma_n * sigma_n;
    }

    let L = robust_cholesky(&K)?;

    // Solve K * alpha = y via two triangular solves with the Cholesky factor
    // (assuming `triangular_solve` solves a single triangular system).
    let z = triangular_solve(&L, &y.to_owned())?;
    let alpha = triangular_solve(&L.t().to_owned(), &z)?;

    let log_det_K = 2.0 * L.diag().mapv(|x| x.ln()).sum();
    let quadratic_term = alpha.dot(y);
    let lml = -0.5 * quadratic_term - 0.5 * log_det_K - 0.5 * n as f64 * (2.0 * PI).ln();

    Ok(lml)
}
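
/// Numerically hardened variant of [`log_marginal_likelihood`]: validates the
/// inputs and noise level, and checks the Cholesky factor and quadratic term
/// for signs of numerical breakdown before returning a value.
///
/// Illustrative sketch (not a doctest; assumes an `RBF` kernel as in the tests
/// below and hypothetical arrays `x` / `y`):
///
/// ```ignore
/// let kernel: Box<dyn Kernel> = Box::new(RBF::new(1.0));
/// // A non-positive noise level is rejected up front.
/// assert!(log_marginal_likelihood_stable(&x.view(), &y.view(), &kernel, 0.0).is_err());
/// let lml = log_marginal_likelihood_stable(&x.view(), &y.view(), &kernel, 0.1)?;
/// assert!(lml.is_finite());
/// ```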
#[allow(non_snake_case)]
pub fn log_marginal_likelihood_stable(
    X: &ArrayView2<f64>,
    y: &ArrayView1<f64>,
    kernel: &Box<dyn Kernel>,
    sigma_n: f64,
) -> SklResult<f64> {
    let n = X.nrows();

    if n == 0 {
        return Err(SklearsError::InvalidInput("Empty input array".to_string()));
    }
    if sigma_n <= 0.0 {
        return Err(SklearsError::InvalidInput(
            "Noise level must be positive".to_string(),
        ));
    }

    let X_owned = X.to_owned();
    let mut K = kernel.compute_kernel_matrix(&X_owned, None)?;

    // Add the noise variance to the diagonal and reject obviously invalid entries.
    let sigma_n_sq = sigma_n * sigma_n;
    for i in 0..n {
        K[[i, i]] += sigma_n_sq;
        if K[[i, i]] <= 0.0 {
            return Err(SklearsError::NumericalError(
                "Kernel matrix is not positive definite".to_string(),
            ));
        }
    }

    let L = robust_cholesky(&K)?;

    // Solve K * alpha = y via two triangular solves with the Cholesky factor
    // (assuming `triangular_solve` solves a single triangular system).
    let z = triangular_solve(&L, &y.to_owned())?;
    let alpha = triangular_solve(&L.t().to_owned(), &z)?;

    // log|K| = 2 * sum(log(diag(L))); a non-positive diagonal entry signals a
    // failed factorization.
    let log_det_K = {
        let log_diag_L: Array1<f64> = L.diag().mapv(|x| {
            if x <= 0.0 {
                return f64::NEG_INFINITY;
            }
            x.ln()
        });

        if log_diag_L.iter().any(|&x| !x.is_finite()) {
            return Err(SklearsError::NumericalError(
                "Numerical instability in Cholesky decomposition".to_string(),
            ));
        }

        2.0 * log_diag_L.sum()
    };

    // y^T K^{-1} y must be finite and non-negative for a positive definite K.
    let quadratic_term = alpha.dot(y);
    if !quadratic_term.is_finite() || quadratic_term < 0.0 {
        return Err(SklearsError::NumericalError(
            "Numerical instability in quadratic term".to_string(),
        ));
    }

    let log_2pi = (2.0 * PI).ln();
    let lml = -0.5 * (quadratic_term + log_det_K + n as f64 * log_2pi);

    if !lml.is_finite() {
        return Err(SklearsError::NumericalError(
            "Log marginal likelihood is not finite".to_string(),
        ));
    }

    Ok(lml)
}

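/// Select the noise level from `sigma_n_values` that gives the best average
/// per-fold score, and return `(best_sigma_n, best_score)`. Samples are
/// shuffled with a simple linear-congruential generator when `random_state`
/// is provided.
///
/// Illustrative sketch (not a doctest; assumes an `RBF` kernel and hypothetical
/// arrays `x_train` / `y_train`):
///
/// ```ignore
/// let mut kernel: Box<dyn Kernel> = Box::new(RBF::new(1.0));
/// let (best_sigma_n, best_score) = cross_validate_hyperparameters(
///     &x_train.view(),
///     &y_train.view(),
///     &mut kernel,
///     &[0.01, 0.1, 1.0],
///     3,
///     Some(42),
/// )?;
/// assert!(best_score.is_finite());
/// ```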
#[allow(non_snake_case)]
pub fn cross_validate_hyperparameters(
    X: &ArrayView2<f64>,
    y: &ArrayView1<f64>,
    kernel: &mut Box<dyn Kernel>,
    sigma_n_values: &[f64],
    n_folds: usize,
    random_state: Option<u64>,
) -> SklResult<(f64, f64)> {
    let n_samples = X.nrows();
    let n_features = X.ncols();
    let fold_size = n_samples / n_folds;

    let mut best_sigma_n = sigma_n_values[0];
    let mut best_score = f64::NEG_INFINITY;

    // Optionally shuffle the sample indices with a simple LCG-driven Fisher-Yates pass.
    let mut indices: Vec<usize> = (0..n_samples).collect();
    if let Some(seed) = random_state {
        let mut rng = seed;
        for i in (1..indices.len()).rev() {
            rng = rng.wrapping_mul(1103515245).wrapping_add(12345);
            let j = (rng as usize) % (i + 1);
            indices.swap(i, j);
        }
    }

    for &sigma_n in sigma_n_values {
        let mut fold_scores = Vec::new();

        for fold in 0..n_folds {
            let start_idx = fold * fold_size;
            let end_idx = if fold == n_folds - 1 {
                n_samples
            } else {
                (fold + 1) * fold_size
            };

            // Split the (shuffled) indices into a training part and a held-out part.
            let mut train_indices = Vec::new();
            let mut val_indices = Vec::new();

            for (i, &idx) in indices.iter().enumerate() {
                if i >= start_idx && i < end_idx {
                    val_indices.push(idx);
                } else {
                    train_indices.push(idx);
                }
            }

            // Score the fold by the log marginal likelihood of the training subset.
            // A predictive score on `val_indices` would require a full GP
            // fit/predict cycle and is not implemented here.
            let mut X_train = Array2::<f64>::zeros((train_indices.len(), n_features));
            let mut y_train = Array1::<f64>::zeros(train_indices.len());
            for (row, &idx) in train_indices.iter().enumerate() {
                X_train.row_mut(row).assign(&X.row(idx));
                y_train[row] = y[idx];
            }
            let lml = log_marginal_likelihood(&X_train.view(), &y_train.view(), kernel, sigma_n)?;
            fold_scores.push(lml);
        }

        let avg_score = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
        if avg_score > best_score {
            best_score = avg_score;
            best_sigma_n = sigma_n;
        }
    }

    Ok((best_sigma_n, best_score))
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernels::RBF;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_log_marginal_likelihood_stable_basic() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];
        let kernel: Box<dyn Kernel> = Box::new(RBF::new(1.0));

        let lml = log_marginal_likelihood_stable(&X.view(), &y.view(), &kernel, 0.1).unwrap();
        assert!(lml.is_finite());
        assert!(lml < 0.0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_log_marginal_likelihood_stable_vs_standard() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];
        let kernel: Box<dyn Kernel> = Box::new(RBF::new(1.0));

        let lml_stable =
            log_marginal_likelihood_stable(&X.view(), &y.view(), &kernel, 0.1).unwrap();
        let lml_standard = log_marginal_likelihood(&X.view(), &y.view(), &kernel, 0.1).unwrap();

        assert_abs_diff_eq!(lml_stable, lml_standard, epsilon = 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_log_marginal_likelihood_stable_input_validation() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];
        let kernel: Box<dyn Kernel> = Box::new(RBF::new(1.0));

        // Negative noise level is rejected.
        let result = log_marginal_likelihood_stable(&X.view(), &y.view(), &kernel, -0.1);
        assert!(result.is_err());

        // Zero noise level is rejected.
        let result = log_marginal_likelihood_stable(&X.view(), &y.view(), &kernel, 0.0);
        assert!(result.is_err());

        // Empty inputs are rejected.
        let X_empty = Array2::<f64>::zeros((0, 1));
        let y_empty = Array1::<f64>::zeros(0);
        let result = log_marginal_likelihood_stable(&X_empty.view(), &y_empty.view(), &kernel, 0.1);
        assert!(result.is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_log_marginal_likelihood_stable_numerical_robustness() {
        // Extremely small inputs and an extreme length scale stress the
        // numerical checks in the stable implementation.
        let X = array![[1e-10], [2e-10], [3e-10]];
        let y = array![1e-10, 2e-10, 3e-10];
        let kernel: Box<dyn Kernel> = Box::new(RBF::new(1e-20));

        let result = log_marginal_likelihood_stable(&X.view(), &y.view(), &kernel, 1e-12);
        match result {
            Ok(lml) => {
                assert!(lml.is_finite());
            }
            Err(e) => {
                assert!(matches!(e, SklearsError::NumericalError(_)));
            }
        }
    }

    #[test]
    fn test_marginal_likelihood_optimizer_creation() {
        let optimizer = MarginalLikelihoodOptimizer::new();
        assert_eq!(optimizer.max_iter, 100);
        assert_eq!(optimizer.tol, 1e-6);
        assert_eq!(optimizer.learning_rate, 0.01);
    }

    #[test]
    fn test_marginal_likelihood_optimizer_builder() {
        let optimizer = MarginalLikelihoodOptimizer::new()
            .max_iter(200)
            .tol(1e-8)
            .learning_rate(0.001)
            .line_search(true)
            .verbose(true);

        assert_eq!(optimizer.max_iter, 200);
        assert_eq!(optimizer.tol, 1e-8);
        assert_eq!(optimizer.learning_rate, 0.001);
        assert!(optimizer.line_search);
        assert!(optimizer.verbose);
    }
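
    // Added smoke test (not part of the original suite): runs the end-to-end
    // optimizer on a tiny dataset with a small iteration budget. It only checks
    // shapes and finiteness, and tolerates a numerical error from the
    // linear-algebra helpers, so it stays robust to their exact behaviour.
    #[test]
    #[allow(non_snake_case)]
    fn test_optimize_hyperparameters_smoke() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];
        let mut kernel: Box<dyn Kernel> = Box::new(RBF::new(1.0));

        // One optimized value per kernel parameter plus the noise level.
        let n_expected = kernel.get_params().len() + 1;

        match optimize_hyperparameters(&X.view(), &y.view(), &mut kernel, 0.1, 5, 1e-6, false) {
            Ok(result) => {
                assert_eq!(result.optimal_params.len(), n_expected);
                assert!(result.optimal_log_marginal_likelihood.is_finite());
                assert!(result.n_iterations >= 1);
                assert!(result.optimal_params.iter().all(|&p| p > 0.0));
            }
            Err(_) => {
                // Numerical edge cases in `robust_cholesky` / `triangular_solve`
                // are tolerated by this smoke test.
            }
        }
    }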
}