1use crate::kernels::Kernel;
8use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
10use scirs2_core::random::Rng;
12use sklears_core::error::{Result as SklResult, SklearsError};
13use sklears_core::prelude::{Estimator, Fit, Predict};
14
/// Sparse-spectrum Gaussian-process regressor (configuration / unfitted state).
///
/// The kernel is approximated by a finite set of spectral points
/// (frequencies); a regularised linear model is then fitted on the cosine and
/// sine features of those frequencies. Configure with the builder methods and
/// train with `Fit::fit`.
#[derive(Debug, Clone)]
pub struct SparseSpectrumGaussianProcessRegressor {
    /// Kernel evaluated on pairs of training rows when estimating the spectral
    /// density of candidate frequencies.
    pub kernel: Box<dyn Kernel>,
    /// Number of spectral points kept after selection; each point yields one
    /// cosine and one sine feature.
    pub num_spectral_points: usize,
    /// Weights strictly above this value count as significant when computing
    /// approximation diagnostics.
    pub spectral_density_threshold: f64,
    /// Strategy used to pick spectral points from the candidate grid.
    pub selection_method: SpectralSelectionMethod,
    /// Seed for candidate sampling and random selection; `None` falls back to
    /// seed 42, so fitting is always deterministic.
    pub random_state: Option<u64>,
    /// Observation-noise variance, added to the diagonal of `ΦᵀΦ` before the
    /// Cholesky factorisation.
    pub noise_variance: f64,
    /// Whether to refine the selected spectral points by finite-difference
    /// gradient descent before the final fit.
    pub optimize_spectral_points: bool,
    /// Step size of that gradient descent.
    pub spectral_learning_rate: f64,
    /// Number of gradient-descent iterations.
    pub max_optimization_iterations: usize,
}
58
/// Strategy used to choose spectral points from the candidate frequency grid.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare or key on the
/// configured strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpectralSelectionMethod {
    /// Uniformly random subset of the candidates, without replacement.
    Random,
    /// Candidates with the highest estimated spectral density.
    Greedy,
    /// Candidates sampled with probability proportional to their density,
    /// topped up at random if the sampling pass falls short.
    ImportanceSampling,
    /// One candidate per evenly spaced stratum of the grid, with random jitter
    /// inside each stratum.
    QuasiRandom,
    /// Roughly 70% of the budget chosen greedily, the remainder at random.
    Adaptive,
}
68
/// Fitted sparse-spectrum GP model produced by `Fit::fit`.
#[derive(Debug, Clone)]
pub struct SparseSpectrumGprTrained {
    /// Configuration the model was fitted with; its feature map and noise
    /// variance are reused at prediction time.
    pub config: SparseSpectrumGaussianProcessRegressor,
    /// Selected (and possibly refined) spectral points, one frequency per row.
    pub spectral_points: Array2<f64>,
    /// Weight of each spectral point; its square root scales the point's
    /// cosine/sine features.
    pub spectral_weights: Array1<f64>,
    /// Feature matrix of the training inputs (two columns per spectral point).
    pub spectral_features: Array2<f64>,
    /// Weight vector applied to the features by `predict`.
    pub posterior_mean: Array1<f64>,
    /// Weight covariance consumed by `predict_with_uncertainty`; its value is
    /// whatever `fit` stores.
    pub posterior_covariance: Array2<f64>,
    /// Copy of the training inputs.
    pub X_train: Array2<f64>,
    /// Copy of the training targets.
    pub y_train: Array1<f64>,
    /// Estimated density of every candidate frequency on the grid (not only
    /// the selected ones).
    pub spectral_density: Array1<f64>,
    /// Log marginal likelihood value computed by `fit`.
    pub log_marginal_likelihood: f64,
}
93
/// Diagnostics describing how the selected spectral points cover the spectrum.
#[derive(Debug, Clone)]
pub struct SpectralApproximationInfo {
    /// Sum of the spectral weights divided by the largest weight.
    pub effective_rank: f64,
    /// Fraction of weights strictly above `spectral_density_threshold`.
    pub spectral_coverage: f64,
    /// One minus the share of total weight carried by points above the
    /// threshold.
    pub max_approximation_error: f64,
    /// The spectral points the diagnostics were computed for.
    pub selected_frequencies: Array2<f64>,
    /// The spectral weights the diagnostics were computed for.
    pub spectral_densities: Array1<f64>,
}
108
109impl Default for SparseSpectrumGaussianProcessRegressor {
110 fn default() -> Self {
111 let kernel = Box::new(crate::kernels::RBF::new(1.0));
113 Self {
114 kernel,
115 num_spectral_points: 100,
116 spectral_density_threshold: 1e-6,
117 selection_method: SpectralSelectionMethod::Adaptive,
118 random_state: Some(42),
119 noise_variance: 1e-5,
120 optimize_spectral_points: true,
121 spectral_learning_rate: 0.01,
122 max_optimization_iterations: 50,
123 }
124 }
125}
126
127impl SparseSpectrumGaussianProcessRegressor {
128 pub fn new(kernel: Box<dyn Kernel>) -> Self {
130 Self {
131 kernel,
132 ..Default::default()
133 }
134 }
135
136 pub fn num_spectral_points(mut self, num_points: usize) -> Self {
138 self.num_spectral_points = num_points;
139 self
140 }
141
142 pub fn spectral_density_threshold(mut self, threshold: f64) -> Self {
144 self.spectral_density_threshold = threshold;
145 self
146 }
147
148 pub fn selection_method(mut self, method: SpectralSelectionMethod) -> Self {
150 self.selection_method = method;
151 self
152 }
153
154 pub fn noise_variance(mut self, variance: f64) -> Self {
156 self.noise_variance = variance;
157 self
158 }
159
160 pub fn optimize_spectral_points(mut self, optimize: bool) -> Self {
162 self.optimize_spectral_points = optimize;
163 self
164 }
165
166 pub fn random_state(mut self, seed: Option<u64>) -> Self {
168 self.random_state = seed;
169 self
170 }
171
172 fn estimate_spectral_density(
174 &self,
175 X: &ArrayView2<f64>,
176 num_grid_points: usize,
177 ) -> SklResult<(Array2<f64>, Array1<f64>)> {
178 let n_features = X.ncols();
179
180 let mut rng = if let Some(seed) = self.random_state {
183 scirs2_core::random::Random::seed(seed)
184 } else {
185 scirs2_core::random::Random::seed(42)
186 };
187
188 let mut freq_ranges = Vec::new();
190 for dim in 0..n_features {
191 let column = X.column(dim);
192 let range = column.fold(f64::NEG_INFINITY, |a, &b| a.max(b))
193 - column.fold(f64::INFINITY, |a, &b| a.min(b));
194 let max_freq = 2.0 / range.max(1e-6);
195 freq_ranges.push((-max_freq, max_freq));
196 }
197
198 let mut frequencies = Array2::zeros((num_grid_points, n_features));
200 let mut spectral_densities = Array1::zeros(num_grid_points);
201
202 for i in 0..num_grid_points {
203 for dim in 0..n_features {
204 let (min_freq, max_freq) = freq_ranges[dim];
205 frequencies[[i, dim]] = rng.gen_range(min_freq..max_freq);
206 }
207
208 spectral_densities[i] =
210 self.estimate_spectral_density_at_frequency(&frequencies.row(i).to_owned(), X)?;
211 }
212
213 Ok((frequencies, spectral_densities))
214 }
215
216 fn estimate_spectral_density_at_frequency(
218 &self,
219 frequency: &Array1<f64>,
220 X: &ArrayView2<f64>,
221 ) -> SklResult<f64> {
222 let n_samples = X.nrows().min(100); let mut density = 0.0;
227
228 for i in 0..n_samples {
229 for j in i + 1..n_samples {
230 let x_diff = &X.row(i) - &X.row(j);
231 let phase = 2.0 * std::f64::consts::PI * frequency.dot(&x_diff);
232 let kernel_value = self.kernel.kernel(&X.row(i), &X.row(j));
233 density += kernel_value * phase.cos();
234 }
235 }
236
237 let normalization = (n_samples * (n_samples - 1)) as f64 / 2.0;
238 Ok((density / normalization).abs())
239 }
240
241 fn select_spectral_points(
243 &self,
244 frequencies: &Array2<f64>,
245 spectral_densities: &Array1<f64>,
246 ) -> SklResult<(Array2<f64>, Array1<f64>)> {
247 let mut rng = if let Some(seed) = self.random_state {
249 scirs2_core::random::Random::seed(seed)
250 } else {
251 scirs2_core::random::Random::seed(42)
252 };
253
254 match self.selection_method {
255 SpectralSelectionMethod::Random => {
256 self.random_selection(frequencies, spectral_densities, &mut rng)
257 }
258 SpectralSelectionMethod::Greedy => {
259 self.greedy_selection(frequencies, spectral_densities)
260 }
261 SpectralSelectionMethod::ImportanceSampling => {
262 self.importance_sampling_selection(frequencies, spectral_densities, &mut rng)
263 }
264 SpectralSelectionMethod::QuasiRandom => {
265 self.quasi_random_selection(frequencies, spectral_densities, &mut rng)
266 }
267 SpectralSelectionMethod::Adaptive => {
268 self.adaptive_selection(frequencies, spectral_densities, &mut rng)
269 }
270 }
271 }
272
273 fn random_selection(
275 &self,
276 frequencies: &Array2<f64>,
277 spectral_densities: &Array1<f64>,
278 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
279 ) -> SklResult<(Array2<f64>, Array1<f64>)> {
280 let total_points = frequencies.nrows();
281 let mut selected_indices = (0..total_points).collect::<Vec<_>>();
282 for i in (1..selected_indices.len()).rev() {
284 let j = rng.gen_range(0..(i + 1));
285 selected_indices.swap(i, j);
286 }
287 selected_indices.truncate(self.num_spectral_points.min(total_points));
288
289 let selected_frequencies = frequencies.select(Axis(0), &selected_indices);
290 let selected_weights = spectral_densities.select(Axis(0), &selected_indices);
291
292 Ok((selected_frequencies, selected_weights))
293 }
294
295 fn greedy_selection(
297 &self,
298 frequencies: &Array2<f64>,
299 spectral_densities: &Array1<f64>,
300 ) -> SklResult<(Array2<f64>, Array1<f64>)> {
301 let mut indices_with_densities: Vec<(usize, f64)> = spectral_densities
302 .iter()
303 .enumerate()
304 .map(|(i, &density)| (i, density))
305 .collect();
306
307 indices_with_densities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
309
310 let selected_indices: Vec<usize> = indices_with_densities
311 .into_iter()
312 .take(self.num_spectral_points.min(frequencies.nrows()))
313 .map(|(idx, _)| idx)
314 .collect();
315
316 let selected_frequencies = frequencies.select(Axis(0), &selected_indices);
317 let selected_weights = spectral_densities.select(Axis(0), &selected_indices);
318
319 Ok((selected_frequencies, selected_weights))
320 }
321
322 fn importance_sampling_selection(
324 &self,
325 frequencies: &Array2<f64>,
326 spectral_densities: &Array1<f64>,
327 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
328 ) -> SklResult<(Array2<f64>, Array1<f64>)> {
329 let total_density: f64 = spectral_densities.sum();
331 if total_density <= 0.0 {
332 return self.random_selection(frequencies, spectral_densities, rng);
333 }
334
335 let probabilities: Array1<f64> = spectral_densities / total_density;
336 let mut selected_indices = Vec::new();
337
338 for _ in 0..self.num_spectral_points.min(frequencies.nrows()) {
339 let mut cumulative = 0.0;
340 let random_value: f64 = rng.gen();
341
342 for (i, &prob) in probabilities.iter().enumerate() {
343 cumulative += prob;
344 if random_value <= cumulative && !selected_indices.contains(&i) {
345 selected_indices.push(i);
346 break;
347 }
348 }
349 }
350
351 while selected_indices.len() < self.num_spectral_points.min(frequencies.nrows()) {
353 let idx = rng.gen_range(0..frequencies.nrows());
354 if !selected_indices.contains(&idx) {
355 selected_indices.push(idx);
356 }
357 }
358
359 let selected_frequencies = frequencies.select(Axis(0), &selected_indices);
360 let selected_weights = spectral_densities.select(Axis(0), &selected_indices);
361
362 Ok((selected_frequencies, selected_weights))
363 }
364
365 fn quasi_random_selection(
367 &self,
368 frequencies: &Array2<f64>,
369 spectral_densities: &Array1<f64>,
370 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
371 ) -> SklResult<(Array2<f64>, Array1<f64>)> {
372 let total_points = frequencies.nrows();
374 let stride = total_points / self.num_spectral_points.max(1);
375
376 let mut selected_indices = Vec::new();
377 for i in 0..self.num_spectral_points.min(total_points) {
378 let base_idx = i * stride;
379 let jitter = rng.gen_range(0..stride.max(1));
380 let idx = (base_idx + jitter).min(total_points - 1);
381 selected_indices.push(idx);
382 }
383
384 let selected_frequencies = frequencies.select(Axis(0), &selected_indices);
385 let selected_weights = spectral_densities.select(Axis(0), &selected_indices);
386
387 Ok((selected_frequencies, selected_weights))
388 }
389
390 fn adaptive_selection(
392 &self,
393 frequencies: &Array2<f64>,
394 spectral_densities: &Array1<f64>,
395 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
396 ) -> SklResult<(Array2<f64>, Array1<f64>)> {
397 let greedy_fraction = 0.7;
399 let num_greedy = (self.num_spectral_points as f64 * greedy_fraction) as usize;
400 let num_random = self.num_spectral_points - num_greedy;
401
402 let (greedy_freqs, greedy_weights) = if num_greedy > 0 {
404 let mut temp_config = self.clone();
405 temp_config.num_spectral_points = num_greedy;
406 temp_config.greedy_selection(frequencies, spectral_densities)?
407 } else {
408 (Array2::zeros((0, frequencies.ncols())), Array1::zeros(0))
409 };
410
411 let (random_freqs, random_weights) = if num_random > 0 {
413 let mut temp_config = self.clone();
414 temp_config.num_spectral_points = num_random;
415 temp_config.random_selection(frequencies, spectral_densities, rng)?
416 } else {
417 (Array2::zeros((0, frequencies.ncols())), Array1::zeros(0))
418 };
419
420 let mut combined_freqs = Array2::zeros((num_greedy + num_random, frequencies.ncols()));
422 let mut combined_weights = Array1::zeros(num_greedy + num_random);
423
424 if num_greedy > 0 {
425 combined_freqs
426 .slice_mut(s![0..num_greedy, ..])
427 .assign(&greedy_freqs);
428 combined_weights
429 .slice_mut(s![0..num_greedy])
430 .assign(&greedy_weights);
431 }
432
433 if num_random > 0 {
434 combined_freqs
435 .slice_mut(s![num_greedy.., ..])
436 .assign(&random_freqs);
437 combined_weights
438 .slice_mut(s![num_greedy..])
439 .assign(&random_weights);
440 }
441
442 Ok((combined_freqs, combined_weights))
443 }
444
445 fn compute_spectral_features(
447 &self,
448 X: &ArrayView2<f64>,
449 spectral_points: &Array2<f64>,
450 spectral_weights: &Array1<f64>,
451 ) -> SklResult<Array2<f64>> {
452 let n_samples = X.nrows();
453 let n_spectral = spectral_points.nrows();
454
455 let mut features = Array2::zeros((n_samples, 2 * n_spectral));
457
458 for i in 0..n_samples {
459 for j in 0..n_spectral {
460 let phase = 2.0 * std::f64::consts::PI * spectral_points.row(j).dot(&X.row(i));
461 let weight_sqrt = spectral_weights[j].sqrt();
462
463 features[[i, 2 * j]] = weight_sqrt * phase.cos();
464 features[[i, 2 * j + 1]] = weight_sqrt * phase.sin();
465 }
466 }
467
468 Ok(features)
469 }
470
471 fn optimize_spectral_points_internal(
473 &self,
474 X: &ArrayView2<f64>,
475 y: &ArrayView1<f64>,
476 mut spectral_points: Array2<f64>,
477 spectral_weights: &Array1<f64>,
478 ) -> SklResult<Array2<f64>> {
479 if !self.optimize_spectral_points {
480 return Ok(spectral_points);
481 }
482
483 for _iteration in 0..self.max_optimization_iterations {
484 let features = self.compute_spectral_features(X, &spectral_points, spectral_weights)?;
486 let objective = self.compute_spectral_objective(&features, y)?;
487
488 let mut gradients = Array2::zeros(spectral_points.raw_dim());
490 let epsilon = 1e-6;
491
492 for i in 0..spectral_points.nrows() {
493 for j in 0..spectral_points.ncols() {
494 spectral_points[[i, j]] += epsilon;
496 let features_plus =
497 self.compute_spectral_features(X, &spectral_points, spectral_weights)?;
498 let objective_plus = self.compute_spectral_objective(&features_plus, y)?;
499
500 spectral_points[[i, j]] -= epsilon;
501 gradients[[i, j]] = (objective_plus - objective) / epsilon;
502 }
503 }
504
505 spectral_points = spectral_points - self.spectral_learning_rate * gradients;
507 }
508
509 Ok(spectral_points)
510 }
511
512 #[allow(non_snake_case)]
514 fn compute_spectral_objective(
515 &self,
516 features: &Array2<f64>,
517 y: &ArrayView1<f64>,
518 ) -> SklResult<f64> {
519 let n_features = features.ncols();
521 let n_samples = features.nrows();
522
523 let phi_t_phi = features.t().dot(features);
525 let gram_matrix = phi_t_phi + Array2::<f64>::eye(n_features) * self.noise_variance;
526
527 let L = crate::utils::cholesky_decomposition(&gram_matrix)?;
529
530 let phi_t_y = features.t().dot(y);
532 let alpha = crate::utils::triangular_solve(&L, &phi_t_y)?;
533 let L_t = L.t();
534 let mean = crate::utils::triangular_solve(&L_t.view().to_owned(), &alpha)?;
535
536 let data_fit = -0.5 * y.dot(&features.dot(&mean));
538 let mut log_det = 0.0;
539 for i in 0..L.nrows() {
540 log_det += L[[i, i]].ln();
541 }
542 let complexity_penalty = -log_det;
543 let normalization = -0.5 * n_samples as f64 * (2.0 * std::f64::consts::PI).ln();
544
545 Ok(-(data_fit + complexity_penalty + normalization))
546 }
547
548 pub fn compute_approximation_info(
550 &self,
551 spectral_points: &Array2<f64>,
552 spectral_weights: &Array1<f64>,
553 ) -> SklResult<SpectralApproximationInfo> {
554 let effective_rank =
555 spectral_weights.sum() / spectral_weights.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
556 let spectral_coverage = (spectral_weights
557 .iter()
558 .filter(|&&w| w > self.spectral_density_threshold)
559 .count() as f64)
560 / spectral_weights.len() as f64;
561
562 let total_spectral_energy = spectral_weights.sum();
564 let selected_energy = spectral_weights
565 .iter()
566 .filter(|&&w| w > self.spectral_density_threshold)
567 .sum::<f64>();
568 let max_approximation_error = 1.0 - (selected_energy / total_spectral_energy.max(1e-10));
569
570 Ok(SpectralApproximationInfo {
571 effective_rank,
572 spectral_coverage,
573 max_approximation_error,
574 selected_frequencies: spectral_points.clone(),
575 spectral_densities: spectral_weights.clone(),
576 })
577 }
578}
579
580impl Estimator for SparseSpectrumGaussianProcessRegressor {
581 type Config = SparseSpectrumGaussianProcessRegressor;
582 type Error = SklearsError;
583 type Float = f64;
584
585 fn config(&self) -> &Self::Config {
586 self
587 }
588}
589
590impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, f64>, SparseSpectrumGprTrained>
591 for SparseSpectrumGaussianProcessRegressor
592{
593 type Fitted = SparseSpectrumGprTrained;
594 #[allow(non_snake_case)]
595 fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<f64>) -> SklResult<SparseSpectrumGprTrained> {
596 if X.nrows() != y.len() {
597 return Err(SklearsError::InvalidInput(
598 "Number of samples in X and y must match".to_string(),
599 ));
600 }
601
602 let grid_size = (self.num_spectral_points * 10).max(1000);
604 let (frequencies, spectral_densities) = self.estimate_spectral_density(X, grid_size)?;
605
606 let (mut spectral_points, spectral_weights) =
608 self.select_spectral_points(&frequencies, &spectral_densities)?;
609
610 spectral_points =
612 self.optimize_spectral_points_internal(X, y, spectral_points, &spectral_weights)?;
613
614 let spectral_features =
616 self.compute_spectral_features(X, &spectral_points, &spectral_weights)?;
617
618 let n_features = spectral_features.ncols();
620 let phi_t_phi = spectral_features.t().dot(&spectral_features);
621 let gram_matrix = phi_t_phi + Array2::<f64>::eye(n_features) * self.noise_variance;
622
623 let L = crate::utils::cholesky_decomposition(&gram_matrix)?;
625
626 let phi_t_y = spectral_features.t().dot(y);
628 let alpha = crate::utils::triangular_solve(&L, &phi_t_y)?;
629 let L_t = L.t();
630 let posterior_mean = crate::utils::triangular_solve(&L_t.view().to_owned(), &alpha)?;
631
632 let posterior_covariance = Array2::<f64>::eye(n_features) / self.noise_variance;
634
635 let data_fit = -0.5 * y.dot(&spectral_features.dot(&posterior_mean));
637 let mut log_det = 0.0;
638 for i in 0..L.nrows() {
639 log_det += L[[i, i]].ln();
640 }
641 let complexity_penalty = -log_det;
642 let normalization = -0.5 * y.len() as f64 * (2.0 * std::f64::consts::PI).ln();
643 let log_marginal_likelihood = data_fit + complexity_penalty + normalization;
644
645 Ok(SparseSpectrumGprTrained {
646 config: self.clone(),
647 spectral_points,
648 spectral_weights,
649 spectral_features,
650 posterior_mean,
651 posterior_covariance,
652 X_train: X.to_owned(),
653 y_train: y.to_owned(),
654 spectral_density: spectral_densities,
655 log_marginal_likelihood,
656 })
657 }
658}
659
660impl Predict<ArrayView2<'_, f64>, Array1<f64>> for SparseSpectrumGprTrained {
661 fn predict(&self, X: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
662 let test_features = self.config.compute_spectral_features(
664 X,
665 &self.spectral_points,
666 &self.spectral_weights,
667 )?;
668
669 let predictions = test_features.dot(&self.posterior_mean);
671 Ok(predictions)
672 }
673}
674
675impl SparseSpectrumGprTrained {
676 pub fn predict_with_uncertainty(
678 &self,
679 X: &ArrayView2<f64>,
680 ) -> SklResult<(Array1<f64>, Array1<f64>)> {
681 let test_features = self.config.compute_spectral_features(
683 X,
684 &self.spectral_points,
685 &self.spectral_weights,
686 )?;
687
688 let predictions = test_features.dot(&self.posterior_mean);
690
691 let mut variances = Array1::zeros(X.nrows());
693 for i in 0..X.nrows() {
694 let feature_vector = test_features.row(i);
695 let variance = feature_vector.dot(&self.posterior_covariance.dot(&feature_vector))
696 + self.config.noise_variance;
697 variances[i] = variance;
698 }
699
700 Ok((predictions, variances))
701 }
702
703 pub fn approximation_info(&self) -> SklResult<SpectralApproximationInfo> {
705 self.config
706 .compute_approximation_info(&self.spectral_points, &self.spectral_weights)
707 }
708
709 pub fn log_marginal_likelihood(&self) -> f64 {
711 self.log_marginal_likelihood
712 }
713}
714
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernels::RBF;
    use scirs2_core::ndarray::{Array1, Array2};

    /// RBF kernel shared by every test.
    fn unit_rbf() -> Box<RBF> {
        Box::new(RBF::new(1.0))
    }

    /// Single-feature design matrix holding `values` as its only column.
    fn single_column(values: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((values.len(), 1), values.to_vec()).unwrap()
    }

    #[test]
    fn test_sparse_spectrum_gpr_creation() {
        let gpr = SparseSpectrumGaussianProcessRegressor::new(unit_rbf())
            .num_spectral_points(50)
            .spectral_density_threshold(1e-6);

        assert_eq!(gpr.num_spectral_points, 50);
        assert_eq!(gpr.spectral_density_threshold, 1e-6);
    }

    #[test]
    fn test_spectral_feature_computation() {
        let gpr = SparseSpectrumGaussianProcessRegressor::new(unit_rbf());
        let X = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let spectral_points = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let spectral_weights = Array1::from_vec(vec![1.0, 0.5]);

        let features = gpr
            .compute_spectral_features(&X.view(), &spectral_points, &spectral_weights)
            .unwrap();

        // Three samples; two spectral points give a cosine and a sine column each.
        assert_eq!(features.dim(), (3, 4));
    }

    #[test]
    fn test_sparse_spectrum_fit_predict() {
        let gpr = SparseSpectrumGaussianProcessRegressor::new(unit_rbf())
            .num_spectral_points(10)
            .optimize_spectral_points(false);
        let X = single_column(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = X.column(0).mapv(|v| v * v);

        let trained = gpr.fit(&X.view(), &y.view()).unwrap();
        let predictions = trained.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 5);
        assert!(trained.log_marginal_likelihood().is_finite());
    }

    #[test]
    fn test_prediction_with_uncertainty() {
        let gpr = SparseSpectrumGaussianProcessRegressor::new(unit_rbf())
            .num_spectral_points(5)
            .optimize_spectral_points(false);
        let X = single_column(&[1.0, 2.0, 3.0]);
        let y = X.column(0).to_owned();

        let trained = gpr.fit(&X.view(), &y.view()).unwrap();
        let (predictions, variances) = trained.predict_with_uncertainty(&X.view()).unwrap();

        assert_eq!(predictions.len(), 3);
        assert_eq!(variances.len(), 3);
        assert!(variances.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn test_spectral_selection_methods() {
        let X = single_column(&[1.0, 2.0, 3.0, 4.0]);
        let y = X.column(0).to_owned();

        for method in [
            SpectralSelectionMethod::Random,
            SpectralSelectionMethod::Greedy,
            SpectralSelectionMethod::ImportanceSampling,
            SpectralSelectionMethod::QuasiRandom,
            SpectralSelectionMethod::Adaptive,
        ] {
            let gpr = SparseSpectrumGaussianProcessRegressor::new(unit_rbf())
                .num_spectral_points(3)
                .selection_method(method)
                .optimize_spectral_points(false);

            assert!(gpr.fit(&X.view(), &y.view()).is_ok());
        }
    }

    #[test]
    fn test_approximation_info() {
        let gpr = SparseSpectrumGaussianProcessRegressor::new(unit_rbf())
            .num_spectral_points(5)
            .optimize_spectral_points(false);
        let X = single_column(&[1.0, 2.0, 3.0]);
        let y = X.column(0).to_owned();

        let info = gpr
            .fit(&X.view(), &y.view())
            .unwrap()
            .approximation_info()
            .unwrap();

        assert!(info.effective_rank > 0.0);
        assert!((0.0..=1.0).contains(&info.spectral_coverage));
        assert!(info.max_approximation_error >= 0.0);
    }
}