sklears_gaussian_process/
automatic_kernel.rs

//! Automatic kernel construction and selection
//!
//! This module provides functionality for automatically constructing and selecting
//! appropriate kernel functions based on data characteristics and statistical tests.
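//!
//! A minimal usage sketch (illustrative only; the exact crate paths and the
//! shapes of `x` and `y` are assumptions, hence the ignored doc test):
//!
//! ```ignore
//! use sklears_gaussian_process::automatic_kernel::AutomaticKernelConstructor;
//!
//! // Configure the constructor via its builder methods, then let it pick a kernel.
//! let constructor = AutomaticKernelConstructor::new()
//!     .max_components(3)
//!     .include_periodic(true)
//!     .use_cross_validation(false);
//!
//! let result = constructor.construct_kernel(x.view(), y.view())?;
//! println!("best score: {}", result.best_score);
//! ```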

use crate::kernels::*;
// SciRS2 Policy - Use scirs2-autograd for ndarray types and operations
use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2, Axis};
use sklears_core::error::{Result as SklResult, SklearsError};

/// Automatic kernel constructor that can intelligently choose and combine kernels
#[derive(Debug, Clone)]
pub struct AutomaticKernelConstructor {
    /// Maximum number of components to consider in composite kernels
    pub max_components: usize,
    /// Whether to consider periodic patterns
    pub include_periodic: bool,
    /// Whether to consider linear trends
    pub include_linear: bool,
    /// Whether to consider polynomial features
    pub include_polynomial: bool,
    /// Minimum correlation threshold for including components
    pub correlation_threshold: f64,
    /// Whether to use cross-validation for kernel selection
    pub use_cross_validation: bool,
    /// Random state for reproducible results
    pub random_state: Option<u64>,
}

impl Default for AutomaticKernelConstructor {
    fn default() -> Self {
        Self {
            max_components: 5,
            include_periodic: true,
            include_linear: true,
            include_polynomial: true,
            correlation_threshold: 0.1,
            use_cross_validation: true,
            random_state: Some(42),
        }
    }
}

/// Result of automatic kernel construction
#[derive(Debug, Clone)]
pub struct KernelConstructionResult {
    /// The best kernel found
    pub best_kernel: Box<dyn Kernel>,
    /// Score of the best kernel (negative log marginal likelihood)
    pub best_score: f64,
    /// All kernels tried, with their scores
    pub kernel_scores: Vec<(String, f64)>,
    /// Data characteristics detected
    pub data_characteristics: DataCharacteristics,
}

/// Data characteristics detected during analysis
#[derive(Debug, Clone)]
pub struct DataCharacteristics {
    /// Dimensionality of the input data
    pub n_dimensions: usize,
    /// Number of data points
    pub n_samples: usize,
    /// Presence of periodic patterns
    pub has_periodicity: bool,
    /// Strength of the linear trend (0-1)
    pub linear_trend_strength: f64,
    /// Noise level estimate
    pub noise_level: f64,
    /// Length scale estimates for each dimension
    pub length_scales: Array1<f64>,
    /// Dominant angular frequencies (if periodic patterns were detected)
    pub dominant_frequencies: Vec<f64>,
}

impl AutomaticKernelConstructor {
    /// Create a new automatic kernel constructor with default settings
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the maximum number of components
    pub fn max_components(mut self, max_components: usize) -> Self {
        self.max_components = max_components;
        self
    }

    /// Set whether to include periodic patterns
    pub fn include_periodic(mut self, include_periodic: bool) -> Self {
        self.include_periodic = include_periodic;
        self
    }

    /// Set whether to include linear trends
    pub fn include_linear(mut self, include_linear: bool) -> Self {
        self.include_linear = include_linear;
        self
    }
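
    /// Set whether to include polynomial features
    pub fn include_polynomial(mut self, include_polynomial: bool) -> Self {
        self.include_polynomial = include_polynomial;
        self
    }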

    /// Set the correlation threshold
    pub fn correlation_threshold(mut self, threshold: f64) -> Self {
        self.correlation_threshold = threshold;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Set whether to use cross-validation for kernel evaluation
    pub fn use_cross_validation(mut self, use_cross_validation: bool) -> Self {
        self.use_cross_validation = use_cross_validation;
        self
    }

    /// Automatically construct the best kernel for the given data
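    ///
    /// The pipeline is: analyze the data, generate candidate kernels, score each
    /// candidate, and return the best one. A sketch of a call (data shapes assumed):
    ///
    /// ```ignore
    /// // X: (n_samples, n_features), y: (n_samples,)
    /// let result = constructor.construct_kernel(X.view(), y.view())?;
    /// for (name, score) in &result.kernel_scores {
    ///     println!("{}: {}", name, score);
    /// }
    /// ```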
    #[allow(non_snake_case)]
    pub fn construct_kernel(
        &self,
        X: ArrayView2<f64>,
        y: ArrayView1<f64>,
    ) -> SklResult<KernelConstructionResult> {
        // Analyze data characteristics
        let characteristics = self.analyze_data_characteristics(&X, &y)?;

        // Generate candidate kernels based on characteristics
        let candidate_kernels = self.generate_candidate_kernels(&characteristics)?;

        // Evaluate kernels and select the best one
        let mut kernel_scores = Vec::new();
        let mut best_kernel: Option<Box<dyn Kernel>> = None;
        let mut best_score = f64::INFINITY;

        for (name, kernel) in candidate_kernels {
            let score = self.evaluate_kernel(&kernel, &X, &y)?;
            kernel_scores.push((name, score));

            if score < best_score {
                best_score = score;
                best_kernel = Some(kernel);
            }
        }

        let best_kernel = best_kernel
            .ok_or_else(|| SklearsError::InvalidOperation("No valid kernels found".to_string()))?;

        Ok(KernelConstructionResult {
            best_kernel,
            best_score,
            kernel_scores,
            data_characteristics: characteristics,
        })
    }

    /// Analyze characteristics of the input data
    #[allow(non_snake_case)]
    fn analyze_data_characteristics(
        &self,
        X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> SklResult<DataCharacteristics> {
        let n_samples = X.nrows();
        let n_dimensions = X.ncols();

        // Estimate the noise level from differences between neighboring targets
        let noise_level = self.estimate_noise_level(X, y)?;

        // Detect linear trend strength
        let linear_trend_strength = self.detect_linear_trend(X, y)?;

        // Estimate length scales for each dimension
        let length_scales = self.estimate_length_scales(X)?;

        // Detect periodicity
        let (has_periodicity, dominant_frequencies) = if self.include_periodic {
            self.detect_periodicity(X, y)?
        } else {
            (false, Vec::new())
        };

        Ok(DataCharacteristics {
            n_dimensions,
            n_samples,
            has_periodicity,
            linear_trend_strength,
            noise_level,
            length_scales,
            dominant_frequencies,
        })
    }

    /// Estimate the noise level from the data
    #[allow(non_snake_case)]
    fn estimate_noise_level(&self, _X: &ArrayView2<f64>, y: &ArrayView1<f64>) -> SklResult<f64> {
        // Simple approach: use the mean absolute difference between consecutive targets
        if y.len() < 2 {
            return Ok(0.1); // Default noise level
        }

        let mut differences = Vec::new();
        for i in 1..y.len() {
            differences.push((y[i] - y[i - 1]).abs());
        }

        let mean_diff = differences.iter().sum::<f64>() / differences.len() as f64;
        Ok(mean_diff.max(1e-6)) // Ensure a minimum noise level
    }

    /// Detect linear trend strength
    #[allow(non_snake_case)]
    fn detect_linear_trend(&self, X: &ArrayView2<f64>, y: &ArrayView1<f64>) -> SklResult<f64> {
        if X.ncols() == 0 || y.is_empty() {
            return Ok(0.0);
        }

        // Simple proxy: absolute correlation between the first feature and the target
        let x_first = X.column(0);
        let correlation = self.compute_correlation(&x_first, y)?;
        Ok(correlation.abs())
    }

    /// Compute the Pearson correlation between two arrays
    fn compute_correlation(&self, x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> SklResult<f64> {
        if x.len() != y.len() || x.is_empty() {
            return Ok(0.0);
        }

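        // Pearson correlation:
        //   r = Σ (x_i - x̄)(y_i - ȳ) / sqrt( Σ (x_i - x̄)² · Σ (y_i - ȳ)² )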
        let x_mean = x.mean().unwrap_or(0.0);
        let y_mean = y.mean().unwrap_or(0.0);

        let mut numerator = 0.0;
        let mut x_var = 0.0;
        let mut y_var = 0.0;

        for i in 0..x.len() {
            let x_diff = x[i] - x_mean;
            let y_diff = y[i] - y_mean;
            numerator += x_diff * y_diff;
            x_var += x_diff * x_diff;
            y_var += y_diff * y_diff;
        }

        let denominator = (x_var * y_var).sqrt();
        if denominator < 1e-10 {
            Ok(0.0)
        } else {
            Ok(numerator / denominator)
        }
    }

    /// Estimate characteristic length scales for each dimension
    #[allow(non_snake_case)]
    fn estimate_length_scales(&self, X: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
        let mut length_scales = Array1::zeros(X.ncols());

        for dim in 0..X.ncols() {
            let column = X.column(dim);
            let range = column.fold(f64::NEG_INFINITY, |a, &b| a.max(b))
                - column.fold(f64::INFINITY, |a, &b| a.min(b));

            // Use a fraction of the range as the initial length scale estimate
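            // (e.g. a dimension spanning [0, 10] gets an initial length scale of 1.0)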
            length_scales[dim] = (range / 10.0).max(1e-3);
        }

        Ok(length_scales)
    }

    /// Detect periodic patterns in the data
    #[allow(non_snake_case)]
    fn detect_periodicity(
        &self,
        _X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> SklResult<(bool, Vec<f64>)> {
        // Simple autocorrelation-based periodicity detection
        let mut dominant_frequencies = Vec::new();

        if y.len() < 10 {
            return Ok((false, dominant_frequencies));
        }

        // Mean-center the signal so a constant offset does not masquerade as correlation
        let y_mean = y.mean().unwrap_or(0.0);
        let centered: Vec<f64> = y.iter().map(|&v| v - y_mean).collect();
        let variance = centered.iter().map(|v| v * v).sum::<f64>() / centered.len() as f64;
        if variance < 1e-12 {
            return Ok((false, dominant_frequencies));
        }

        // Compute normalized autocorrelation at different lags
        let max_lag = (y.len() / 4).min(50);
        let mut autocorr = Vec::new();

        for lag in 1..max_lag {
            let mut correlation = 0.0;
            let mut count = 0;

            for i in lag..centered.len() {
                correlation += centered[i] * centered[i - lag];
                count += 1;
            }

            if count > 0 {
                // Normalize by the variance so values lie roughly in [-1, 1]
                // and are comparable to the threshold below
                autocorr.push(correlation / (count as f64 * variance));
            }
        }

        // Find peaks in the autocorrelation
        let threshold = 0.3; // Minimum normalized autocorrelation for considering periodicity
        let mut has_periodicity = false;

        for i in 1..autocorr.len().saturating_sub(1) {
            if autocorr[i] > threshold
                && autocorr[i] > autocorr[i - 1]
                && autocorr[i] > autocorr[i + 1]
            {
                has_periodicity = true;
                // autocorr[i] corresponds to lag i + 1; convert the lag to an
                // angular frequency (approximate)
                let frequency = 2.0 * std::f64::consts::PI / (i as f64 + 1.0);
                dominant_frequencies.push(frequency);
            }
        }

        Ok((has_periodicity, dominant_frequencies))
    }

    /// Generate candidate kernels based on data characteristics
    fn generate_candidate_kernels(
        &self,
        characteristics: &DataCharacteristics,
    ) -> SklResult<Vec<(String, Box<dyn Kernel>)>> {
        let mut kernels = Vec::new();

        // Base RBF kernel (always included)
        let base_length_scale = characteristics.length_scales.mean().unwrap_or(1.0);
        kernels.push((
            "RBF".to_string(),
            Box::new(RBF::new(base_length_scale)) as Box<dyn Kernel>,
        ));

        // ARD RBF kernel if multi-dimensional
        if characteristics.n_dimensions > 1 {
            kernels.push((
                "ARD_RBF".to_string(),
                Box::new(crate::kernels::ARDRBF::new(
                    characteristics.length_scales.clone(),
                )) as Box<dyn Kernel>,
            ));
        }

        // Matérn kernels (nu = 1/2 and nu = 3/2)
        kernels.push((
            "Matern_1_2".to_string(),
            Box::new(Matern::new(base_length_scale, 0.5)) as Box<dyn Kernel>,
        ));
        kernels.push((
            "Matern_3_2".to_string(),
            Box::new(Matern::new(base_length_scale, 1.5)) as Box<dyn Kernel>,
        ));

        // Linear kernel if a strong linear trend was detected
        if self.include_linear && characteristics.linear_trend_strength > 0.3 {
            kernels.push((
                "Linear".to_string(),
                Box::new(Linear::new(1.0, 1.0)) as Box<dyn Kernel>,
            ));

            // RBF + Linear combination
            let rbf = Box::new(RBF::new(base_length_scale));
            let linear = Box::new(Linear::new(1.0, 1.0));
            kernels.push((
                "RBF+Linear".to_string(),
                Box::new(crate::kernels::SumKernel::new(vec![rbf, linear])) as Box<dyn Kernel>,
            ));
        }

        // Periodic kernels if periodicity was detected
        if self.include_periodic && characteristics.has_periodicity {
            for &freq in &characteristics.dominant_frequencies {
                let period = 2.0 * std::f64::consts::PI / freq;
                kernels.push((
                    format!("ExpSineSquared_{:.2}", period),
                    Box::new(ExpSineSquared::new(base_length_scale, period)) as Box<dyn Kernel>,
                ));

                // RBF * ExpSineSquared combination
                let rbf = Box::new(RBF::new(base_length_scale));
                let periodic = Box::new(ExpSineSquared::new(base_length_scale, period));
                kernels.push((
                    format!("RBF*ExpSineSquared_{:.2}", period),
                    Box::new(crate::kernels::ProductKernel::new(vec![rbf, periodic]))
                        as Box<dyn Kernel>,
                ));
            }
        }

        // Rational Quadratic (good for multiple length scales)
        kernels.push((
            "RationalQuadratic".to_string(),
            Box::new(RationalQuadratic::new(base_length_scale, 1.0)) as Box<dyn Kernel>,
        ));

        // Note: a SpectralMixture kernel can be added when properly exported

        Ok(kernels)
    }

    /// Evaluate a kernel using cross-validation or simple marginal likelihood
    #[allow(non_snake_case)]
    fn evaluate_kernel(
        &self,
        kernel: &dyn Kernel,
        X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> SklResult<f64> {
        if self.use_cross_validation && X.nrows() > 10 {
            self.cross_validate_kernel(kernel, X, y)
        } else {
            self.evaluate_marginal_likelihood(kernel, X, y)
        }
    }

    /// Simple marginal likelihood evaluation
    #[allow(non_snake_case)]
    fn evaluate_marginal_likelihood(
        &self,
        kernel: &dyn Kernel,
        X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> SklResult<f64> {
        // Compute kernel matrix
        let X_owned = X.to_owned();
        let K = kernel.compute_kernel_matrix(&X_owned, Some(&X_owned))?;

        // Add noise to the diagonal for numerical stability
        let mut K_noisy = K;
        let noise_var = 0.1; // Simple fixed noise estimate
        for i in 0..K_noisy.nrows() {
            K_noisy[[i, i]] += noise_var;
        }

        // Compute Cholesky decomposition
        match crate::utils::cholesky_decomposition(&K_noisy) {
            Ok(L) => {
                // log|K| = 2 * sum of the log-diagonal of the Cholesky factor
                let mut log_det = 0.0;
                for i in 0..L.nrows() {
                    log_det += L[[i, i]].ln();
                }
                log_det *= 2.0;

                // Solve for alpha = K^(-1) * y via two triangular solves
                let y_owned = y.to_owned();
                let alpha = match crate::utils::triangular_solve(&L, &y_owned) {
                    Ok(temp) => {
                        let L_T = L.t().to_owned();
                        crate::utils::triangular_solve(&L_T, &temp)?
                    }
                    Err(_) => return Ok(f64::INFINITY), // Numerical issues
                };

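                // Log marginal likelihood (the standard GP identity, see
                // Rasmussen & Williams, "Gaussian Processes for Machine Learning", Eq. 2.30):
                //   log p(y | X) = -1/2 * y^T K^{-1} y - 1/2 * log|K| - (n/2) * log(2*pi)
                // The three terms below mirror that decomposition; the final negation
                // turns it into a minimization score.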
                let data_fit = -0.5 * y.dot(&alpha);
                let complexity_penalty = -0.5 * log_det;
                let normalization = -0.5 * y.len() as f64 * (2.0 * std::f64::consts::PI).ln();

                Ok(-(data_fit + complexity_penalty + normalization))
            }
            Err(_) => Ok(f64::INFINITY), // Kernel matrix not positive definite
        }
    }

    /// Cross-validation based kernel evaluation
    ///
    /// Note: the held-out fold is currently extracted but not yet scored; each fold
    /// is evaluated by the marginal likelihood of its training split, so this acts
    /// as an averaged-likelihood proxy rather than true predictive cross-validation.
    #[allow(non_snake_case)]
    fn cross_validate_kernel(
        &self,
        kernel: &dyn Kernel,
        X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> SklResult<f64> {
        let n_folds = 5.min(X.nrows() / 2);
        if n_folds < 2 {
            return self.evaluate_marginal_likelihood(kernel, X, y);
        }

        let fold_size = X.nrows() / n_folds;
        let mut total_score = 0.0;

        for fold in 0..n_folds {
            let start_idx = fold * fold_size;
            let end_idx = if fold == n_folds - 1 {
                X.nrows()
            } else {
                (fold + 1) * fold_size
            };

            // Create contiguous train/test splits
            let mut train_indices = Vec::new();
            let mut test_indices = Vec::new();

            for i in 0..X.nrows() {
                if i >= start_idx && i < end_idx {
                    test_indices.push(i);
                } else {
                    train_indices.push(i);
                }
            }

            if train_indices.is_empty() || test_indices.is_empty() {
                continue;
            }

            // Extract train and (currently unused) test data
            let X_train = X.select(Axis(0), &train_indices);
            let y_train = y.select(Axis(0), &train_indices);
            let _X_test = X.select(Axis(0), &test_indices);
            let _y_test = y.select(Axis(0), &test_indices);

            // Score this fold by the marginal likelihood of its training split
            let fold_score =
                self.evaluate_marginal_likelihood(kernel, &X_train.view(), &y_train.view())?;

            total_score += fold_score;
        }

        Ok(total_score / n_folds as f64)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    // SciRS2 Policy - Use scirs2-autograd for ndarray types and operations
    use scirs2_core::ndarray::{Array1, Array2};

    #[test]
    fn test_automatic_kernel_constructor_creation() {
        let constructor = AutomaticKernelConstructor::new();
        assert_eq!(constructor.max_components, 5);
        assert!(constructor.include_periodic);
        assert!(constructor.include_linear);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_data_characteristics_analysis() {
        let constructor = AutomaticKernelConstructor::new();

        // Create simple test data
        let X = Array2::from_shape_vec(
            (10, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
                9.0, 10.0, 10.0, 11.0,
            ],
        )
        .unwrap();
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);

        let characteristics = constructor
            .analyze_data_characteristics(&X.view(), &y.view())
            .unwrap();

        assert_eq!(characteristics.n_dimensions, 2);
        assert_eq!(characteristics.n_samples, 10);
        assert!(characteristics.linear_trend_strength > 0.5);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_kernel_construction() {
        let constructor = AutomaticKernelConstructor::new()
            .max_components(3)
            .use_cross_validation(false);

        // Create simple test data
        let X = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = Array1::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0]);

        let result = constructor.construct_kernel(X.view(), y.view()).unwrap();

        assert!(result.best_score.is_finite());
        assert!(!result.kernel_scores.is_empty());
    }

    #[test]
    fn test_correlation_computation() {
        let constructor = AutomaticKernelConstructor::new();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let correlation = constructor
            .compute_correlation(&x.view(), &y.view())
            .unwrap();
        assert!((correlation - 1.0).abs() < 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_length_scale_estimation() {
        let constructor = AutomaticKernelConstructor::new();
        let X = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 5.0, 10.0, 10.0, 20.0]).unwrap();

        let length_scales = constructor.estimate_length_scales(&X.view()).unwrap();

        assert_eq!(length_scales.len(), 2);
        assert!(length_scales[0] > 0.0);
        assert!(length_scales[1] > 0.0);
    }
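
    // Additional sketch tests for the internal estimators; the data and thresholds
    // below are illustrative, chosen to match the heuristics documented above.

    #[test]
    #[allow(non_snake_case)]
    fn test_noise_level_estimation() {
        let constructor = AutomaticKernelConstructor::new();
        let X = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
        // Consecutive absolute differences are [1.0, 1.0], so the estimate is 1.0
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0]);

        let noise = constructor
            .estimate_noise_level(&X.view(), &y.view())
            .unwrap();
        assert!((noise - 1.0).abs() < 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_periodicity_detection_on_sine_wave() {
        let constructor = AutomaticKernelConstructor::new();
        let X = Array2::from_shape_vec((50, 1), (0..50).map(|i| i as f64).collect()).unwrap();
        // A pure sine with a period of 10 samples; the normalized autocorrelation
        // should peak near lag 10 and flag periodicity
        let y =
            Array1::from_shape_fn(50, |i| (2.0 * std::f64::consts::PI * i as f64 / 10.0).sin());

        let (has_periodicity, frequencies) = constructor
            .detect_periodicity(&X.view(), &y.view())
            .unwrap();
        assert!(has_periodicity);
        assert!(!frequencies.is_empty());
    }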
}