sklears_gaussian_process/
sparse_spectrum.rs

1//! Sparse spectrum Gaussian processes for large-scale approximation
2//!
3//! This module implements sparse spectrum Gaussian processes (SSGPs) which use
4//! spectral approximation methods to scale to large datasets while maintaining
5//! accurate uncertainty quantification.
6
7use crate::kernels::Kernel;
8// SciRS2 Policy - Use scirs2-autograd for ndarray types and operations
9use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
10// SciRS2 Policy - Use scirs2-core for random number generation
11use scirs2_core::random::Rng;
12use sklears_core::error::{Result as SklResult, SklearsError};
13use sklears_core::prelude::{Estimator, Fit, Predict};
14
/// Sparse Spectrum Gaussian Process Regressor
///
/// Uses spectral approximation to scale Gaussian processes to large datasets
/// by approximating the kernel using a sparse set of spectral points. Each
/// selected spectral point contributes one cosine and one sine feature, and
/// the regression itself is a Bayesian linear model in that feature space.
///
/// # Example
/// ```rust
/// use sklears_gaussian_process::{SparseSpectrumGaussianProcessRegressor, kernels::RBF};
/// use sklears_core::prelude::*;
/// // ndarray types are taken from the same re-export this module uses.
/// // NOTE(review): assumes `scirs2_core` is reachable from doctests - confirm.
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let kernel = Box::new(RBF::new(1.0));
/// let model = SparseSpectrumGaussianProcessRegressor::new(kernel)
///     .num_spectral_points(50)
///     .spectral_density_threshold(1e-6);
///
/// let X = Array2::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect()).unwrap();
/// let y = Array1::from_vec((0..10).map(|x| (x as f64).sin()).collect());
///
/// let trained_model = model.fit(&X.view(), &y.view()).unwrap();
/// let predictions = trained_model.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct SparseSpectrumGaussianProcessRegressor {
    /// Base kernel to approximate; it is evaluated on pairs of training rows
    /// when the spectral density is estimated.
    pub kernel: Box<dyn Kernel>,
    /// Number of spectral points to select. The resulting feature matrix has
    /// twice as many columns (one cosine and one sine column per point).
    pub num_spectral_points: usize,
    /// Weight threshold used by `compute_approximation_info` to decide which
    /// selected weights count towards coverage and captured energy.
    pub spectral_density_threshold: f64,
    /// Method for selecting spectral points from the candidate frequency grid
    pub selection_method: SpectralSelectionMethod,
    /// Seed for the random number generators used during fitting.
    /// When `None`, the fixed seed 42 is used, so fitting stays deterministic.
    pub random_state: Option<u64>,
    /// Noise variance; added to the diagonal of the feature Gram matrix and to
    /// every predictive variance.
    pub noise_variance: f64,
    /// Whether to refine the selected spectral points by gradient descent
    /// during training
    pub optimize_spectral_points: bool,
    /// Step size of the gradient-descent update of the spectral points
    pub spectral_learning_rate: f64,
    /// Number of gradient-descent iterations for spectral point optimization
    pub max_optimization_iterations: usize,
}
58
/// Methods for selecting spectral points from the candidate frequency grid
#[derive(Debug, Clone, Copy)]
pub enum SpectralSelectionMethod {
    /// Shuffle the candidates and keep the first `num_spectral_points`.
    Random,
    /// Keep the candidates with the largest estimated spectral density.
    Greedy,
    /// Draw candidates without replacement, favouring high-density ones;
    /// falls back to `Random` when the total density is not positive.
    ImportanceSampling,
    /// Stratified sampling: one randomly jittered candidate per equally
    /// sized stride of the grid.
    QuasiRandom,
    /// Mix of `Greedy` (70% of the points) and `Random` (the remainder).
    Adaptive,
}
68
/// Trained sparse spectrum Gaussian process regressor
///
/// Produced by fitting a [`SparseSpectrumGaussianProcessRegressor`]; holds
/// everything needed to predict on new inputs.
#[derive(Debug, Clone)]
pub struct SparseSpectrumGprTrained {
    /// Configuration the model was fitted with
    pub config: SparseSpectrumGaussianProcessRegressor,
    /// Selected spectral points (frequencies), one row per point
    pub spectral_points: Array2<f64>,
    /// Spectral weights of the selected points (their estimated densities);
    /// the square root of each weight scales that point's features
    pub spectral_weights: Array1<f64>,
    /// Training feature matrix in spectral space; one row per training sample
    /// and two columns (cosine, sine) per spectral point
    pub spectral_features: Array2<f64>,
    /// Posterior mean of the linear weights in spectral feature space
    pub posterior_mean: Array1<f64>,
    /// Posterior covariance matrix of the linear weights.
    /// NOTE(review): `fit` currently stores the simplified `I / noise_variance`
    /// here rather than the exact posterior covariance.
    pub posterior_covariance: Array2<f64>,
    /// Copy of the training inputs
    pub X_train: Array2<f64>,
    /// Copy of the training targets
    pub y_train: Array1<f64>,
    /// Spectral density estimates for the whole candidate frequency grid
    /// (not only for the selected points)
    pub spectral_density: Array1<f64>,
    /// Log marginal likelihood computed at the end of fitting
    pub log_marginal_likelihood: f64,
}
93
/// Information about spectral approximation quality
///
/// All figures are derived from the weights of the *selected* spectral
/// points only.
#[derive(Debug, Clone)]
pub struct SpectralApproximationInfo {
    /// Effective rank of spectral approximation: the sum of the weights
    /// divided by the largest weight
    pub effective_rank: f64,
    /// Spectral coverage: fraction of the selected weights that exceed the
    /// configured `spectral_density_threshold`
    pub spectral_coverage: f64,
    /// Rough approximation error estimate: one minus the share of the total
    /// weight carried by weights above the threshold
    pub max_approximation_error: f64,
    /// Selected frequencies (copy of the spectral points)
    pub selected_frequencies: Array2<f64>,
    /// Spectral density at selected points (copy of the spectral weights)
    pub spectral_densities: Array1<f64>,
}
108
109impl Default for SparseSpectrumGaussianProcessRegressor {
110    fn default() -> Self {
111        // Default to RBF kernel
112        let kernel = Box::new(crate::kernels::RBF::new(1.0));
113        Self {
114            kernel,
115            num_spectral_points: 100,
116            spectral_density_threshold: 1e-6,
117            selection_method: SpectralSelectionMethod::Adaptive,
118            random_state: Some(42),
119            noise_variance: 1e-5,
120            optimize_spectral_points: true,
121            spectral_learning_rate: 0.01,
122            max_optimization_iterations: 50,
123        }
124    }
125}
126
127impl SparseSpectrumGaussianProcessRegressor {
128    /// Create a new sparse spectrum Gaussian process regressor
129    pub fn new(kernel: Box<dyn Kernel>) -> Self {
130        Self {
131            kernel,
132            ..Default::default()
133        }
134    }
135
136    /// Set the number of spectral points
137    pub fn num_spectral_points(mut self, num_points: usize) -> Self {
138        self.num_spectral_points = num_points;
139        self
140    }
141
142    /// Set the spectral density threshold
143    pub fn spectral_density_threshold(mut self, threshold: f64) -> Self {
144        self.spectral_density_threshold = threshold;
145        self
146    }
147
148    /// Set the spectral selection method
149    pub fn selection_method(mut self, method: SpectralSelectionMethod) -> Self {
150        self.selection_method = method;
151        self
152    }
153
154    /// Set the noise variance
155    pub fn noise_variance(mut self, variance: f64) -> Self {
156        self.noise_variance = variance;
157        self
158    }
159
160    /// Set whether to optimize spectral points
161    pub fn optimize_spectral_points(mut self, optimize: bool) -> Self {
162        self.optimize_spectral_points = optimize;
163        self
164    }
165
166    /// Set random state for reproducible results
167    pub fn random_state(mut self, seed: Option<u64>) -> Self {
168        self.random_state = seed;
169        self
170    }
171
172    /// Estimate spectral density of the kernel
173    fn estimate_spectral_density(
174        &self,
175        X: &ArrayView2<f64>,
176        num_grid_points: usize,
177    ) -> SklResult<(Array2<f64>, Array1<f64>)> {
178        let n_features = X.ncols();
179
180        // Create a grid of frequencies
181        // SciRS2 Policy - Use scirs2-core for random number generation
182        let mut rng = if let Some(seed) = self.random_state {
183            scirs2_core::random::Random::seed(seed)
184        } else {
185            scirs2_core::random::Random::seed(42)
186        };
187
188        // Estimate reasonable frequency range from data
189        let mut freq_ranges = Vec::new();
190        for dim in 0..n_features {
191            let column = X.column(dim);
192            let range = column.fold(f64::NEG_INFINITY, |a, &b| a.max(b))
193                - column.fold(f64::INFINITY, |a, &b| a.min(b));
194            let max_freq = 2.0 / range.max(1e-6);
195            freq_ranges.push((-max_freq, max_freq));
196        }
197
198        // Generate frequency grid
199        let mut frequencies = Array2::zeros((num_grid_points, n_features));
200        let mut spectral_densities = Array1::zeros(num_grid_points);
201
202        for i in 0..num_grid_points {
203            for dim in 0..n_features {
204                let (min_freq, max_freq) = freq_ranges[dim];
205                frequencies[[i, dim]] = rng.gen_range(min_freq..max_freq);
206            }
207
208            // Estimate spectral density at this frequency point
209            spectral_densities[i] =
210                self.estimate_spectral_density_at_frequency(&frequencies.row(i).to_owned(), X)?;
211        }
212
213        Ok((frequencies, spectral_densities))
214    }
215
216    /// Estimate spectral density at a specific frequency
217    fn estimate_spectral_density_at_frequency(
218        &self,
219        frequency: &Array1<f64>,
220        X: &ArrayView2<f64>,
221    ) -> SklResult<f64> {
222        // For RBF-like kernels, the spectral density follows a Gaussian distribution
223        // For other kernels, we use a general Fourier transform approximation
224
225        let n_samples = X.nrows().min(100); // Use subset for efficiency
226        let mut density = 0.0;
227
228        for i in 0..n_samples {
229            for j in i + 1..n_samples {
230                let x_diff = &X.row(i) - &X.row(j);
231                let phase = 2.0 * std::f64::consts::PI * frequency.dot(&x_diff);
232                let kernel_value = self.kernel.kernel(&X.row(i), &X.row(j));
233                density += kernel_value * phase.cos();
234            }
235        }
236
237        let normalization = (n_samples * (n_samples - 1)) as f64 / 2.0;
238        Ok((density / normalization).abs())
239    }
240
241    /// Select spectral points based on the selection method
242    fn select_spectral_points(
243        &self,
244        frequencies: &Array2<f64>,
245        spectral_densities: &Array1<f64>,
246    ) -> SklResult<(Array2<f64>, Array1<f64>)> {
247        // SciRS2 Policy - Use scirs2-core for random number generation
248        let mut rng = if let Some(seed) = self.random_state {
249            scirs2_core::random::Random::seed(seed)
250        } else {
251            scirs2_core::random::Random::seed(42)
252        };
253
254        match self.selection_method {
255            SpectralSelectionMethod::Random => {
256                self.random_selection(frequencies, spectral_densities, &mut rng)
257            }
258            SpectralSelectionMethod::Greedy => {
259                self.greedy_selection(frequencies, spectral_densities)
260            }
261            SpectralSelectionMethod::ImportanceSampling => {
262                self.importance_sampling_selection(frequencies, spectral_densities, &mut rng)
263            }
264            SpectralSelectionMethod::QuasiRandom => {
265                self.quasi_random_selection(frequencies, spectral_densities, &mut rng)
266            }
267            SpectralSelectionMethod::Adaptive => {
268                self.adaptive_selection(frequencies, spectral_densities, &mut rng)
269            }
270        }
271    }
272
273    /// Random selection of spectral points
274    fn random_selection(
275        &self,
276        frequencies: &Array2<f64>,
277        spectral_densities: &Array1<f64>,
278        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
279    ) -> SklResult<(Array2<f64>, Array1<f64>)> {
280        let total_points = frequencies.nrows();
281        let mut selected_indices = (0..total_points).collect::<Vec<_>>();
282        // Simple shuffle using Fisher-Yates algorithm
283        for i in (1..selected_indices.len()).rev() {
284            let j = rng.gen_range(0..(i + 1));
285            selected_indices.swap(i, j);
286        }
287        selected_indices.truncate(self.num_spectral_points.min(total_points));
288
289        let selected_frequencies = frequencies.select(Axis(0), &selected_indices);
290        let selected_weights = spectral_densities.select(Axis(0), &selected_indices);
291
292        Ok((selected_frequencies, selected_weights))
293    }
294
295    /// Greedy selection based on spectral density
296    fn greedy_selection(
297        &self,
298        frequencies: &Array2<f64>,
299        spectral_densities: &Array1<f64>,
300    ) -> SklResult<(Array2<f64>, Array1<f64>)> {
301        let mut indices_with_densities: Vec<(usize, f64)> = spectral_densities
302            .iter()
303            .enumerate()
304            .map(|(i, &density)| (i, density))
305            .collect();
306
307        // Sort by spectral density (descending)
308        indices_with_densities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
309
310        let selected_indices: Vec<usize> = indices_with_densities
311            .into_iter()
312            .take(self.num_spectral_points.min(frequencies.nrows()))
313            .map(|(idx, _)| idx)
314            .collect();
315
316        let selected_frequencies = frequencies.select(Axis(0), &selected_indices);
317        let selected_weights = spectral_densities.select(Axis(0), &selected_indices);
318
319        Ok((selected_frequencies, selected_weights))
320    }
321
322    /// Importance sampling selection
323    fn importance_sampling_selection(
324        &self,
325        frequencies: &Array2<f64>,
326        spectral_densities: &Array1<f64>,
327        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
328    ) -> SklResult<(Array2<f64>, Array1<f64>)> {
329        // Normalize spectral densities to create probability distribution
330        let total_density: f64 = spectral_densities.sum();
331        if total_density <= 0.0 {
332            return self.random_selection(frequencies, spectral_densities, rng);
333        }
334
335        let probabilities: Array1<f64> = spectral_densities / total_density;
336        let mut selected_indices = Vec::new();
337
338        for _ in 0..self.num_spectral_points.min(frequencies.nrows()) {
339            let mut cumulative = 0.0;
340            let random_value: f64 = rng.gen();
341
342            for (i, &prob) in probabilities.iter().enumerate() {
343                cumulative += prob;
344                if random_value <= cumulative && !selected_indices.contains(&i) {
345                    selected_indices.push(i);
346                    break;
347                }
348            }
349        }
350
351        // Fill remaining slots with random selection if needed
352        while selected_indices.len() < self.num_spectral_points.min(frequencies.nrows()) {
353            let idx = rng.gen_range(0..frequencies.nrows());
354            if !selected_indices.contains(&idx) {
355                selected_indices.push(idx);
356            }
357        }
358
359        let selected_frequencies = frequencies.select(Axis(0), &selected_indices);
360        let selected_weights = spectral_densities.select(Axis(0), &selected_indices);
361
362        Ok((selected_frequencies, selected_weights))
363    }
364
365    /// Quasi-random selection using low-discrepancy sequences
366    fn quasi_random_selection(
367        &self,
368        frequencies: &Array2<f64>,
369        spectral_densities: &Array1<f64>,
370        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
371    ) -> SklResult<(Array2<f64>, Array1<f64>)> {
372        // Simple implementation: stratified sampling
373        let total_points = frequencies.nrows();
374        let stride = total_points / self.num_spectral_points.max(1);
375
376        let mut selected_indices = Vec::new();
377        for i in 0..self.num_spectral_points.min(total_points) {
378            let base_idx = i * stride;
379            let jitter = rng.gen_range(0..stride.max(1));
380            let idx = (base_idx + jitter).min(total_points - 1);
381            selected_indices.push(idx);
382        }
383
384        let selected_frequencies = frequencies.select(Axis(0), &selected_indices);
385        let selected_weights = spectral_densities.select(Axis(0), &selected_indices);
386
387        Ok((selected_frequencies, selected_weights))
388    }
389
390    /// Adaptive selection based on data characteristics
391    fn adaptive_selection(
392        &self,
393        frequencies: &Array2<f64>,
394        spectral_densities: &Array1<f64>,
395        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
396    ) -> SklResult<(Array2<f64>, Array1<f64>)> {
397        // Combine greedy selection for high-density regions with random exploration
398        let greedy_fraction = 0.7;
399        let num_greedy = (self.num_spectral_points as f64 * greedy_fraction) as usize;
400        let num_random = self.num_spectral_points - num_greedy;
401
402        // Greedy selection for high-density points
403        let (greedy_freqs, greedy_weights) = if num_greedy > 0 {
404            let mut temp_config = self.clone();
405            temp_config.num_spectral_points = num_greedy;
406            temp_config.greedy_selection(frequencies, spectral_densities)?
407        } else {
408            (Array2::zeros((0, frequencies.ncols())), Array1::zeros(0))
409        };
410
411        // Random selection for exploration
412        let (random_freqs, random_weights) = if num_random > 0 {
413            let mut temp_config = self.clone();
414            temp_config.num_spectral_points = num_random;
415            temp_config.random_selection(frequencies, spectral_densities, rng)?
416        } else {
417            (Array2::zeros((0, frequencies.ncols())), Array1::zeros(0))
418        };
419
420        // Combine results
421        let mut combined_freqs = Array2::zeros((num_greedy + num_random, frequencies.ncols()));
422        let mut combined_weights = Array1::zeros(num_greedy + num_random);
423
424        if num_greedy > 0 {
425            combined_freqs
426                .slice_mut(s![0..num_greedy, ..])
427                .assign(&greedy_freqs);
428            combined_weights
429                .slice_mut(s![0..num_greedy])
430                .assign(&greedy_weights);
431        }
432
433        if num_random > 0 {
434            combined_freqs
435                .slice_mut(s![num_greedy.., ..])
436                .assign(&random_freqs);
437            combined_weights
438                .slice_mut(s![num_greedy..])
439                .assign(&random_weights);
440        }
441
442        Ok((combined_freqs, combined_weights))
443    }
444
445    /// Compute spectral features for given data points
446    fn compute_spectral_features(
447        &self,
448        X: &ArrayView2<f64>,
449        spectral_points: &Array2<f64>,
450        spectral_weights: &Array1<f64>,
451    ) -> SklResult<Array2<f64>> {
452        let n_samples = X.nrows();
453        let n_spectral = spectral_points.nrows();
454
455        // Each spectral point contributes cos and sin features
456        let mut features = Array2::zeros((n_samples, 2 * n_spectral));
457
458        for i in 0..n_samples {
459            for j in 0..n_spectral {
460                let phase = 2.0 * std::f64::consts::PI * spectral_points.row(j).dot(&X.row(i));
461                let weight_sqrt = spectral_weights[j].sqrt();
462
463                features[[i, 2 * j]] = weight_sqrt * phase.cos();
464                features[[i, 2 * j + 1]] = weight_sqrt * phase.sin();
465            }
466        }
467
468        Ok(features)
469    }
470
471    /// Optimize spectral points using gradient-based optimization
472    fn optimize_spectral_points_internal(
473        &self,
474        X: &ArrayView2<f64>,
475        y: &ArrayView1<f64>,
476        mut spectral_points: Array2<f64>,
477        spectral_weights: &Array1<f64>,
478    ) -> SklResult<Array2<f64>> {
479        if !self.optimize_spectral_points {
480            return Ok(spectral_points);
481        }
482
483        for _iteration in 0..self.max_optimization_iterations {
484            // Compute current features and objective
485            let features = self.compute_spectral_features(X, &spectral_points, spectral_weights)?;
486            let objective = self.compute_spectral_objective(&features, y)?;
487
488            // Compute gradients (simplified finite differences)
489            let mut gradients = Array2::zeros(spectral_points.raw_dim());
490            let epsilon = 1e-6;
491
492            for i in 0..spectral_points.nrows() {
493                for j in 0..spectral_points.ncols() {
494                    // Forward difference
495                    spectral_points[[i, j]] += epsilon;
496                    let features_plus =
497                        self.compute_spectral_features(X, &spectral_points, spectral_weights)?;
498                    let objective_plus = self.compute_spectral_objective(&features_plus, y)?;
499
500                    spectral_points[[i, j]] -= epsilon;
501                    gradients[[i, j]] = (objective_plus - objective) / epsilon;
502                }
503            }
504
505            // Update spectral points
506            spectral_points = spectral_points - self.spectral_learning_rate * gradients;
507        }
508
509        Ok(spectral_points)
510    }
511
512    /// Compute objective function for spectral point optimization
513    #[allow(non_snake_case)]
514    fn compute_spectral_objective(
515        &self,
516        features: &Array2<f64>,
517        y: &ArrayView1<f64>,
518    ) -> SklResult<f64> {
519        // Use negative log marginal likelihood as objective
520        let n_features = features.ncols();
521        let n_samples = features.nrows();
522
523        // Compute Phi^T Phi + noise_variance * I
524        let phi_t_phi = features.t().dot(features);
525        let gram_matrix = phi_t_phi + Array2::<f64>::eye(n_features) * self.noise_variance;
526
527        // Cholesky decomposition
528        let L = crate::utils::cholesky_decomposition(&gram_matrix)?;
529
530        // Solve for posterior mean
531        let phi_t_y = features.t().dot(y);
532        let alpha = crate::utils::triangular_solve(&L, &phi_t_y)?;
533        let L_t = L.t();
534        let mean = crate::utils::triangular_solve(&L_t.view().to_owned(), &alpha)?;
535
536        // Compute log marginal likelihood
537        let data_fit = -0.5 * y.dot(&features.dot(&mean));
538        let mut log_det = 0.0;
539        for i in 0..L.nrows() {
540            log_det += L[[i, i]].ln();
541        }
542        let complexity_penalty = -log_det;
543        let normalization = -0.5 * n_samples as f64 * (2.0 * std::f64::consts::PI).ln();
544
545        Ok(-(data_fit + complexity_penalty + normalization))
546    }
547
548    /// Compute spectral approximation quality metrics
549    pub fn compute_approximation_info(
550        &self,
551        spectral_points: &Array2<f64>,
552        spectral_weights: &Array1<f64>,
553    ) -> SklResult<SpectralApproximationInfo> {
554        let effective_rank =
555            spectral_weights.sum() / spectral_weights.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
556        let spectral_coverage = (spectral_weights
557            .iter()
558            .filter(|&&w| w > self.spectral_density_threshold)
559            .count() as f64)
560            / spectral_weights.len() as f64;
561
562        // Rough approximation error estimate
563        let total_spectral_energy = spectral_weights.sum();
564        let selected_energy = spectral_weights
565            .iter()
566            .filter(|&&w| w > self.spectral_density_threshold)
567            .sum::<f64>();
568        let max_approximation_error = 1.0 - (selected_energy / total_spectral_energy.max(1e-10));
569
570        Ok(SpectralApproximationInfo {
571            effective_rank,
572            spectral_coverage,
573            max_approximation_error,
574            selected_frequencies: spectral_points.clone(),
575            spectral_densities: spectral_weights.clone(),
576        })
577    }
578}
579
/// The regressor acts as its own configuration object: `config` simply
/// returns a reference to `self`.
impl Estimator for SparseSpectrumGaussianProcessRegressor {
    type Config = SparseSpectrumGaussianProcessRegressor;
    type Error = SklearsError;
    type Float = f64;

    /// Return the configuration, which is the regressor itself.
    fn config(&self) -> &Self::Config {
        self
    }
}
589
590impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, f64>, SparseSpectrumGprTrained>
591    for SparseSpectrumGaussianProcessRegressor
592{
593    type Fitted = SparseSpectrumGprTrained;
594    #[allow(non_snake_case)]
595    fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<f64>) -> SklResult<SparseSpectrumGprTrained> {
596        if X.nrows() != y.len() {
597            return Err(SklearsError::InvalidInput(
598                "Number of samples in X and y must match".to_string(),
599            ));
600        }
601
602        // Estimate spectral density
603        let grid_size = (self.num_spectral_points * 10).max(1000);
604        let (frequencies, spectral_densities) = self.estimate_spectral_density(X, grid_size)?;
605
606        // Select spectral points
607        let (mut spectral_points, spectral_weights) =
608            self.select_spectral_points(&frequencies, &spectral_densities)?;
609
610        // Optimize spectral points if requested
611        spectral_points =
612            self.optimize_spectral_points_internal(X, y, spectral_points, &spectral_weights)?;
613
614        // Compute spectral features
615        let spectral_features =
616            self.compute_spectral_features(X, &spectral_points, &spectral_weights)?;
617
618        // Bayesian linear regression in spectral feature space
619        let n_features = spectral_features.ncols();
620        let phi_t_phi = spectral_features.t().dot(&spectral_features);
621        let gram_matrix = phi_t_phi + Array2::<f64>::eye(n_features) * self.noise_variance;
622
623        // Cholesky decomposition
624        let L = crate::utils::cholesky_decomposition(&gram_matrix)?;
625
626        // Compute posterior mean
627        let phi_t_y = spectral_features.t().dot(y);
628        let alpha = crate::utils::triangular_solve(&L, &phi_t_y)?;
629        let L_t = L.t();
630        let posterior_mean = crate::utils::triangular_solve(&L_t.view().to_owned(), &alpha)?;
631
632        // Compute posterior covariance (simplified - use diagonal approximation)
633        let posterior_covariance = Array2::<f64>::eye(n_features) / self.noise_variance;
634
635        // Compute log marginal likelihood
636        let data_fit = -0.5 * y.dot(&spectral_features.dot(&posterior_mean));
637        let mut log_det = 0.0;
638        for i in 0..L.nrows() {
639            log_det += L[[i, i]].ln();
640        }
641        let complexity_penalty = -log_det;
642        let normalization = -0.5 * y.len() as f64 * (2.0 * std::f64::consts::PI).ln();
643        let log_marginal_likelihood = data_fit + complexity_penalty + normalization;
644
645        Ok(SparseSpectrumGprTrained {
646            config: self.clone(),
647            spectral_points,
648            spectral_weights,
649            spectral_features,
650            posterior_mean,
651            posterior_covariance,
652            X_train: X.to_owned(),
653            y_train: y.to_owned(),
654            spectral_density: spectral_densities,
655            log_marginal_likelihood,
656        })
657    }
658}
659
660impl Predict<ArrayView2<'_, f64>, Array1<f64>> for SparseSpectrumGprTrained {
661    fn predict(&self, X: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
662        // Compute spectral features for test data
663        let test_features = self.config.compute_spectral_features(
664            X,
665            &self.spectral_points,
666            &self.spectral_weights,
667        )?;
668
669        // Compute predictions
670        let predictions = test_features.dot(&self.posterior_mean);
671        Ok(predictions)
672    }
673}
674
675impl SparseSpectrumGprTrained {
676    /// Predict with uncertainty quantification
677    pub fn predict_with_uncertainty(
678        &self,
679        X: &ArrayView2<f64>,
680    ) -> SklResult<(Array1<f64>, Array1<f64>)> {
681        // Compute spectral features for test data
682        let test_features = self.config.compute_spectral_features(
683            X,
684            &self.spectral_points,
685            &self.spectral_weights,
686        )?;
687
688        // Compute predictions
689        let predictions = test_features.dot(&self.posterior_mean);
690
691        // Compute predictive variance
692        let mut variances = Array1::zeros(X.nrows());
693        for i in 0..X.nrows() {
694            let feature_vector = test_features.row(i);
695            let variance = feature_vector.dot(&self.posterior_covariance.dot(&feature_vector))
696                + self.config.noise_variance;
697            variances[i] = variance;
698        }
699
700        Ok((predictions, variances))
701    }
702
703    /// Get spectral approximation quality information
704    pub fn approximation_info(&self) -> SklResult<SpectralApproximationInfo> {
705        self.config
706            .compute_approximation_info(&self.spectral_points, &self.spectral_weights)
707    }
708
709    /// Get log marginal likelihood
710    pub fn log_marginal_likelihood(&self) -> f64 {
711        self.log_marginal_likelihood
712    }
713}
714
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernels::RBF;
    // ndarray types come from the scirs2-core re-export
    use scirs2_core::ndarray::{Array1, Array2};

    /// RBF-based regressor with the given number of spectral points and
    /// spectral point optimization switched off (keeps the tests fast).
    fn fast_regressor(num_points: usize) -> SparseSpectrumGaussianProcessRegressor {
        SparseSpectrumGaussianProcessRegressor::new(Box::new(RBF::new(1.0)))
            .num_spectral_points(num_points)
            .optimize_spectral_points(false)
    }

    /// Single-feature inputs `1, 2, ..., n` as an `(n, 1)` matrix.
    fn column_inputs(n: usize) -> Array2<f64> {
        Array2::from_shape_vec((n, 1), (1..=n).map(|v| v as f64).collect()).unwrap()
    }

    #[test]
    fn test_sparse_spectrum_gpr_creation() {
        let gpr = SparseSpectrumGaussianProcessRegressor::new(Box::new(RBF::new(1.0)))
            .num_spectral_points(50)
            .spectral_density_threshold(1e-6);

        assert_eq!(gpr.num_spectral_points, 50);
        assert_eq!(gpr.spectral_density_threshold, 1e-6);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_spectral_feature_computation() {
        let gpr = SparseSpectrumGaussianProcessRegressor::new(Box::new(RBF::new(1.0)));

        let X = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let spectral_points = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let spectral_weights = Array1::from_vec(vec![1.0, 0.5]);

        let features = gpr
            .compute_spectral_features(&X.view(), &spectral_points, &spectral_weights)
            .unwrap();

        // One row per sample, a cos and a sin column per spectral point
        assert_eq!(features.nrows(), 3);
        assert_eq!(features.ncols(), 4);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_sparse_spectrum_fit_predict() {
        let X = column_inputs(5);
        let y = Array1::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0]);

        let trained = fast_regressor(10).fit(&X.view(), &y.view()).unwrap();
        let predictions = trained.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 5);
        assert!(trained.log_marginal_likelihood().is_finite());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_prediction_with_uncertainty() {
        let X = column_inputs(3);
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let trained = fast_regressor(5).fit(&X.view(), &y.view()).unwrap();
        let (predictions, variances) = trained.predict_with_uncertainty(&X.view()).unwrap();

        assert_eq!(predictions.len(), 3);
        assert_eq!(variances.len(), 3);
        // Variances should be non-negative
        assert!(variances.iter().all(|&v| v >= 0.0));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_spectral_selection_methods() {
        let methods = [
            SpectralSelectionMethod::Random,
            SpectralSelectionMethod::Greedy,
            SpectralSelectionMethod::ImportanceSampling,
            SpectralSelectionMethod::QuasiRandom,
            SpectralSelectionMethod::Adaptive,
        ];

        for method in methods {
            let gpr = fast_regressor(3).selection_method(method);

            let X = column_inputs(4);
            let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

            let result = gpr.fit(&X.view(), &y.view());
            assert!(result.is_ok());
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_approximation_info() {
        let X = column_inputs(3);
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let trained = fast_regressor(5).fit(&X.view(), &y.view()).unwrap();
        let info = trained.approximation_info().unwrap();

        assert!(info.effective_rank > 0.0);
        assert!(info.spectral_coverage >= 0.0 && info.spectral_coverage <= 1.0);
        assert!(info.max_approximation_error >= 0.0);
    }
}