use crate::kernels::*;
use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2, Axis};
use sklears_core::error::{Result as SklResult, SklearsError};

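/// Automatically constructs and selects a Gaussian process kernel from data.
///
/// The constructor analyses the training data (noise level, linear trend,
/// characteristic length scales, periodicity) and scores a set of candidate
/// kernels, keeping the candidate with the best (lowest) score.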
#[derive(Debug, Clone)]
pub struct AutomaticKernelConstructor {
    pub max_components: usize,
    pub include_periodic: bool,
    pub include_linear: bool,
    pub include_polynomial: bool,
    pub correlation_threshold: f64,
    pub use_cross_validation: bool,
    pub random_state: Option<u64>,
}

impl Default for AutomaticKernelConstructor {
    fn default() -> Self {
        Self {
            max_components: 5,
            include_periodic: true,
            include_linear: true,
            include_polynomial: true,
            correlation_threshold: 0.1,
            use_cross_validation: true,
            random_state: Some(42),
        }
    }
}

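/// Result of an automatic kernel construction run: the selected kernel, its
/// score, the scores of all evaluated candidates, and the data
/// characteristics that drove candidate generation.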
#[derive(Debug, Clone)]
pub struct KernelConstructionResult {
    pub best_kernel: Box<dyn Kernel>,
    pub best_score: f64,
    pub kernel_scores: Vec<(String, f64)>,
    pub data_characteristics: DataCharacteristics,
}

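/// Summary statistics of the training data used to guide kernel selection.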
#[derive(Debug, Clone)]
pub struct DataCharacteristics {
    pub n_dimensions: usize,
    pub n_samples: usize,
    pub has_periodicity: bool,
    pub linear_trend_strength: f64,
    pub noise_level: f64,
    pub length_scales: Array1<f64>,
    pub dominant_frequencies: Vec<f64>,
}

impl AutomaticKernelConstructor {
    /// Creates a constructor with the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of kernel components.
    pub fn max_components(mut self, max_components: usize) -> Self {
        self.max_components = max_components;
        self
    }

    /// Enables or disables periodic kernel candidates.
    pub fn include_periodic(mut self, include_periodic: bool) -> Self {
        self.include_periodic = include_periodic;
        self
    }

    /// Enables or disables linear kernel candidates.
    pub fn include_linear(mut self, include_linear: bool) -> Self {
        self.include_linear = include_linear;
        self
    }

    /// Sets the correlation threshold.
    pub fn correlation_threshold(mut self, threshold: f64) -> Self {
        self.correlation_threshold = threshold;
        self
    }

    /// Sets the random seed.
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Enables or disables cross-validated kernel scoring.
    pub fn use_cross_validation(mut self, use_cross_validation: bool) -> Self {
        self.use_cross_validation = use_cross_validation;
        self
    }

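    /// Analyses the data, generates candidate kernels, and returns the best
    /// scoring candidate along with all candidate scores.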
    #[allow(non_snake_case)]
    pub fn construct_kernel(
        &self,
        X: ArrayView2<f64>,
        y: ArrayView1<f64>,
    ) -> SklResult<KernelConstructionResult> {
        let characteristics = self.analyze_data_characteristics(&X, &y)?;

        let candidate_kernels = self.generate_candidate_kernels(&characteristics)?;

        let mut kernel_scores = Vec::new();
        let mut best_kernel: Option<Box<dyn Kernel>> = None;
        let mut best_score = f64::INFINITY;

        // Lower scores are better (negative log marginal likelihood / CV score).
        for (name, kernel) in candidate_kernels {
            let score = self.evaluate_kernel(&kernel, &X, &y)?;
            kernel_scores.push((name, score));

            if score < best_score {
                best_score = score;
                best_kernel = Some(kernel);
            }
        }

        let best_kernel = best_kernel
            .ok_or_else(|| SklearsError::InvalidOperation("No valid kernels found".to_string()))?;

        Ok(KernelConstructionResult {
            best_kernel,
            best_score,
            kernel_scores,
            data_characteristics: characteristics,
        })
    }

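    /// Computes summary statistics (noise level, linear trend strength,
    /// length scales, and optional periodicity) for the training data.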
    #[allow(non_snake_case)]
    fn analyze_data_characteristics(
        &self,
        X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> SklResult<DataCharacteristics> {
        let n_samples = X.nrows();
        let n_dimensions = X.ncols();

        let noise_level = self.estimate_noise_level(X, y)?;

        let linear_trend_strength = self.detect_linear_trend(X, y)?;

        let length_scales = self.estimate_length_scales(X)?;

        let (has_periodicity, dominant_frequencies) = if self.include_periodic {
            self.detect_periodicity(X, y)?
        } else {
            (false, Vec::new())
        };

        Ok(DataCharacteristics {
            n_dimensions,
            n_samples,
            has_periodicity,
            linear_trend_strength,
            noise_level,
            length_scales,
            dominant_frequencies,
        })
    }

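    /// Estimates the noise level as the mean absolute first difference of the
    /// targets, falling back to a small default for very short inputs.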
    #[allow(non_snake_case)]
    fn estimate_noise_level(&self, _X: &ArrayView2<f64>, y: &ArrayView1<f64>) -> SklResult<f64> {
        if y.len() < 2 {
            return Ok(0.1); // default noise level for degenerate inputs
        }

        // Mean absolute difference between consecutive targets as a rough noise proxy.
        let mut differences = Vec::new();
        for i in 1..y.len() {
            differences.push((y[i] - y[i - 1]).abs());
        }

        let mean_diff = differences.iter().sum::<f64>() / differences.len() as f64;
        Ok(mean_diff.max(1e-6)) // keep the estimate strictly positive
    }

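    /// Returns the absolute correlation between the first feature and the
    /// targets as a simple measure of linear trend strength.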
    #[allow(non_snake_case)]
    fn detect_linear_trend(&self, X: &ArrayView2<f64>, y: &ArrayView1<f64>) -> SklResult<f64> {
        if X.ncols() == 0 || y.is_empty() {
            return Ok(0.0);
        }

        let x_first = X.column(0);
        let correlation = self.compute_correlation(&x_first, y)?;
        Ok(correlation.abs())
    }

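    /// Computes the Pearson correlation coefficient between two equally sized
    /// vectors, returning 0.0 for degenerate inputs.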
    fn compute_correlation(&self, x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> SklResult<f64> {
        if x.len() != y.len() || x.is_empty() {
            return Ok(0.0);
        }

        let x_mean = x.mean().unwrap_or(0.0);
        let y_mean = y.mean().unwrap_or(0.0);

        let mut numerator = 0.0;
        let mut x_var = 0.0;
        let mut y_var = 0.0;

        for i in 0..x.len() {
            let x_diff = x[i] - x_mean;
            let y_diff = y[i] - y_mean;
            numerator += x_diff * y_diff;
            x_var += x_diff * x_diff;
            y_var += y_diff * y_diff;
        }

        let denominator = (x_var * y_var).sqrt();
        if denominator < 1e-10 {
            Ok(0.0)
        } else {
            Ok(numerator / denominator)
        }
    }

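    /// Estimates a per-dimension length scale as one tenth of the value range
    /// of each feature, clamped to a small positive minimum.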
    #[allow(non_snake_case)]
    fn estimate_length_scales(&self, X: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
        let mut length_scales = Array1::zeros(X.ncols());

        for dim in 0..X.ncols() {
            let column = X.column(dim);
            let range = column.fold(f64::NEG_INFINITY, |a, &b| a.max(b))
                - column.fold(f64::INFINITY, |a, &b| a.min(b));

            length_scales[dim] = (range / 10.0).max(1e-3);
        }

        Ok(length_scales)
    }

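    /// Looks for periodic structure in the targets via peaks in an
    /// autocorrelation estimate and returns the corresponding angular frequencies.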
    #[allow(non_snake_case)]
    fn detect_periodicity(
        &self,
        _X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> SklResult<(bool, Vec<f64>)> {
        let mut dominant_frequencies = Vec::new();

        if y.len() < 10 {
            return Ok((false, dominant_frequencies));
        }

        // Center the series and normalize by its variance so the peak threshold
        // below is independent of the scale of the targets.
        let y_mean = y.mean().unwrap_or(0.0);
        let y_var = y.iter().map(|&v| (v - y_mean).powi(2)).sum::<f64>() / y.len() as f64;
        if y_var < 1e-12 {
            return Ok((false, dominant_frequencies));
        }

        let max_lag = (y.len() / 4).min(50);
        let mut autocorr = Vec::new();

        for lag in 1..max_lag {
            let mut correlation = 0.0;
            let mut count = 0;

            for i in lag..y.len() {
                correlation += (y[i] - y_mean) * (y[i - lag] - y_mean);
                count += 1;
            }

            if count > 0 {
                autocorr.push(correlation / (count as f64 * y_var));
            }
        }

        // A local maximum above this threshold is treated as evidence of periodicity.
        let threshold = 0.3;
        let mut has_periodicity = false;

        for i in 1..autocorr.len().saturating_sub(1) {
            if autocorr[i] > threshold
                && autocorr[i] > autocorr[i - 1]
                && autocorr[i] > autocorr[i + 1]
            {
                has_periodicity = true;
                // `autocorr[i]` corresponds to a lag of `i + 1` samples.
                let frequency = 2.0 * std::f64::consts::PI / (i as f64 + 1.0);
                dominant_frequencies.push(frequency);
            }
        }

        Ok((has_periodicity, dominant_frequencies))
    }

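    /// Builds the list of candidate kernels (RBF, ARD-RBF, Matern, optional
    /// linear and periodic compositions, rational quadratic) based on the
    /// measured data characteristics.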
    fn generate_candidate_kernels(
        &self,
        characteristics: &DataCharacteristics,
    ) -> SklResult<Vec<(String, Box<dyn Kernel>)>> {
        let mut kernels = Vec::new();

        // Isotropic RBF baseline using the average estimated length scale.
        let base_length_scale = characteristics.length_scales.mean().unwrap_or(1.0);
        kernels.push((
            "RBF".to_string(),
            Box::new(RBF::new(base_length_scale)) as Box<dyn Kernel>,
        ));

        // ARD variant with a separate length scale per dimension.
        if characteristics.n_dimensions > 1 {
            kernels.push((
                "ARD_RBF".to_string(),
                Box::new(crate::kernels::ARDRBF::new(
                    characteristics.length_scales.clone(),
                )) as Box<dyn Kernel>,
            ));
        }

        // Matern kernels for rougher functions.
        kernels.push((
            "Matern_1_2".to_string(),
            Box::new(Matern::new(base_length_scale, 0.5)) as Box<dyn Kernel>,
        ));
        kernels.push((
            "Matern_3_2".to_string(),
            Box::new(Matern::new(base_length_scale, 1.5)) as Box<dyn Kernel>,
        ));

        // Linear candidates when a clear linear trend was detected.
        if self.include_linear && characteristics.linear_trend_strength > 0.3 {
            kernels.push((
                "Linear".to_string(),
                Box::new(Linear::new(1.0, 1.0)) as Box<dyn Kernel>,
            ));

            let rbf: Box<dyn Kernel> = Box::new(RBF::new(base_length_scale));
            let linear: Box<dyn Kernel> = Box::new(Linear::new(1.0, 1.0));
            kernels.push((
                "RBF+Linear".to_string(),
                Box::new(crate::kernels::SumKernel::new(vec![rbf, linear])) as Box<dyn Kernel>,
            ));
        }

        // Periodic candidates for each detected dominant frequency.
        if self.include_periodic && characteristics.has_periodicity {
            for &freq in &characteristics.dominant_frequencies {
                let period = 2.0 * std::f64::consts::PI / freq;
                kernels.push((
                    format!("ExpSineSquared_{:.2}", period),
                    Box::new(ExpSineSquared::new(base_length_scale, period)) as Box<dyn Kernel>,
                ));

                let rbf: Box<dyn Kernel> = Box::new(RBF::new(base_length_scale));
                let periodic: Box<dyn Kernel> =
                    Box::new(ExpSineSquared::new(base_length_scale, period));
                kernels.push((
                    format!("RBF*ExpSineSquared_{:.2}", period),
                    Box::new(crate::kernels::ProductKernel::new(vec![rbf, periodic]))
                        as Box<dyn Kernel>,
                ));
            }
        }

        kernels.push((
            "RationalQuadratic".to_string(),
            Box::new(RationalQuadratic::new(base_length_scale, 1.0)) as Box<dyn Kernel>,
        ));

        Ok(kernels)
    }

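    /// Scores a kernel, using cross-validation when enabled and more than 10
    /// samples are available, and the marginal likelihood otherwise.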
    #[allow(non_snake_case)]
    fn evaluate_kernel(
        &self,
        kernel: &dyn Kernel,
        X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> SklResult<f64> {
        if self.use_cross_validation && X.nrows() > 10 {
            self.cross_validate_kernel(kernel, X, y)
        } else {
            self.evaluate_marginal_likelihood(kernel, X, y)
        }
    }

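    /// Returns the negative log marginal likelihood of the targets under the
    /// kernel with a fixed additive noise variance (lower is better).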
    #[allow(non_snake_case)]
    fn evaluate_marginal_likelihood(
        &self,
        kernel: &dyn Kernel,
        X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> SklResult<f64> {
        let X_owned = X.to_owned();
        let K = kernel.compute_kernel_matrix(&X_owned, Some(&X_owned))?;

        // Add a fixed noise variance to the diagonal for numerical stability.
        let mut K_noisy = K;
        let noise_var = 0.1;
        for i in 0..K_noisy.nrows() {
            K_noisy[[i, i]] += noise_var;
        }

        match crate::utils::cholesky_decomposition(&K_noisy) {
            Ok(L) => {
                // log |K| = 2 * sum(log(diag(L)))
                let mut log_det = 0.0;
                for i in 0..L.nrows() {
                    log_det += L[[i, i]].ln();
                }
                log_det *= 2.0;

                // Solve K * alpha = y via the two triangular systems.
                let y_owned = y.to_owned();
                let alpha = match crate::utils::triangular_solve(&L, &y_owned) {
                    Ok(temp) => {
                        let L_T = L.t();
                        crate::utils::triangular_solve(&L_T.to_owned(), &temp)?
                    }
                    Err(_) => return Ok(f64::INFINITY), // treat solve failures as a very poor score
                };

                let data_fit = -0.5 * y.dot(&alpha);
                let complexity_penalty = -0.5 * log_det;
                let normalization = -0.5 * y.len() as f64 * (2.0 * std::f64::consts::PI).ln();

                Ok(-(data_fit + complexity_penalty + normalization))
            }
            Err(_) => Ok(f64::INFINITY), // non-positive-definite kernel matrix
        }
    }

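    /// Scores a kernel with a simple contiguous k-fold split (at most 5 folds);
    /// each fold is scored by the marginal likelihood of its training portion.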
    #[allow(non_snake_case)]
    fn cross_validate_kernel(
        &self,
        kernel: &dyn Kernel,
        X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> SklResult<f64> {
        let n_folds = 5.min(X.nrows() / 2);
        if n_folds < 2 {
            return self.evaluate_marginal_likelihood(kernel, X, y);
        }

        let fold_size = X.nrows() / n_folds;
        let mut total_score = 0.0;

        for fold in 0..n_folds {
            let start_idx = fold * fold_size;
            let end_idx = if fold == n_folds - 1 {
                X.nrows()
            } else {
                (fold + 1) * fold_size
            };

            // Split contiguous blocks of rows into train and test indices.
            let mut train_indices = Vec::new();
            let mut test_indices = Vec::new();

            for i in 0..X.nrows() {
                if i >= start_idx && i < end_idx {
                    test_indices.push(i);
                } else {
                    train_indices.push(i);
                }
            }

            if train_indices.is_empty() || test_indices.is_empty() {
                continue;
            }

            let X_train = X.select(Axis(0), &train_indices);
            let y_train = y.select(Axis(0), &train_indices);
            // NOTE: the held-out fold is selected but not yet used for predictive
            // scoring; the fold score below is the marginal likelihood of the
            // training portion only.
            let _X_test = X.select(Axis(0), &test_indices);
            let _y_test = y.select(Axis(0), &test_indices);

            let fold_score =
                self.evaluate_marginal_likelihood(kernel, &X_train.view(), &y_train.view())?;

            total_score += fold_score;
        }

        Ok(total_score / n_folds as f64)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    #[test]
    fn test_automatic_kernel_constructor_creation() {
        let constructor = AutomaticKernelConstructor::new();
        assert_eq!(constructor.max_components, 5);
        assert!(constructor.include_periodic);
        assert!(constructor.include_linear);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_data_characteristics_analysis() {
        let constructor = AutomaticKernelConstructor::new();

        let X = Array2::from_shape_vec(
            (10, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
                9.0, 10.0, 10.0, 11.0,
            ],
        )
        .unwrap();
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);

        let characteristics = constructor
            .analyze_data_characteristics(&X.view(), &y.view())
            .unwrap();

        assert_eq!(characteristics.n_dimensions, 2);
        assert_eq!(characteristics.n_samples, 10);
        assert!(characteristics.linear_trend_strength > 0.5);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_kernel_construction() {
        let constructor = AutomaticKernelConstructor::new()
            .max_components(3)
            .use_cross_validation(false);

        let X = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = Array1::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0]);

        let result = constructor.construct_kernel(X.view(), y.view()).unwrap();

        assert!(result.best_score.is_finite());
        assert!(!result.kernel_scores.is_empty());
    }

    #[test]
    fn test_correlation_computation() {
        let constructor = AutomaticKernelConstructor::new();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let correlation = constructor
            .compute_correlation(&x.view(), &y.view())
            .unwrap();
        assert!((correlation - 1.0).abs() < 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_length_scale_estimation() {
        let constructor = AutomaticKernelConstructor::new();
        let X = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 5.0, 10.0, 10.0, 20.0]).unwrap();

        let length_scales = constructor.estimate_length_scales(&X.view()).unwrap();

        assert_eq!(length_scales.len(), 2);
        assert!(length_scales[0] > 0.0);
        assert!(length_scales[1] > 0.0);
    }
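
    // Illustrative sketch (not part of the original suite): chains the builder
    // options end to end on a tiny, roughly linear dataset. It assumes only the
    // public API defined above and that at least one candidate kernel yields a
    // finite marginal-likelihood score.
    #[test]
    #[allow(non_snake_case)]
    fn test_builder_configuration_sketch() {
        let constructor = AutomaticKernelConstructor::new()
            .include_periodic(false)
            .include_linear(true)
            .correlation_threshold(0.2)
            .random_state(None)
            .use_cross_validation(false);

        assert!(!constructor.include_periodic);
        assert!((constructor.correlation_threshold - 0.2).abs() < 1e-12);
        assert_eq!(constructor.random_state, None);

        let X = Array2::from_shape_vec((6, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.9, 2.1, 2.9, 4.2, 5.1]);

        let result = constructor.construct_kernel(X.view(), y.view()).unwrap();
        assert!(result.best_score.is_finite());
        assert!(!result.kernel_scores.is_empty());
    }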
}