sklears_gaussian_process/
kernel_structure_learning.rs

1//! Advanced kernel structure learning using grammar-based search
2//!
3//! This module implements sophisticated kernel structure learning algorithms that can
4//! automatically discover complex kernel compositions using grammar-based search,
5//! statistical tests, and structure optimization.
6
7use crate::kernels::*;
// SciRS2 Policy - Use scirs2-core for ndarray types and operations
9use scirs2_core::ndarray::{ArrayView1, ArrayView2};
10// SciRS2 Policy - Use scirs2-core for random number generation
11use scirs2_core::random::Rng;
12use sklears_core::error::{Result as SklResult, SklearsError};
13
14/// Grammar-based kernel structure learning
15#[derive(Debug, Clone)]
16pub struct KernelStructureLearner {
17    /// Maximum depth of kernel expressions
18    pub max_depth: usize,
19    /// Maximum number of iterations for structure search
20    pub max_iterations: usize,
21    /// Probability of adding new components
22    pub expansion_probability: f64,
23    /// Probability of simplifying structures
24    pub simplification_probability: f64,
25    /// Minimum improvement threshold for accepting new structures
26    pub improvement_threshold: f64,
27    /// Whether to use Bayesian information criterion for model selection
28    pub use_bic: bool,
29    /// Random state for reproducible results
30    pub random_state: Option<u64>,
31    /// Search strategy
32    pub search_strategy: SearchStrategy,
33}
34
/// Search strategies for kernel structure learning
///
/// Scores are minimised by every strategy: a lower score means a better kernel.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SearchStrategy {
    /// Greedy hill climbing - move to the first candidate that improves the score
    Greedy,
    /// Beam search - keep only the best-scoring structures after each expansion
    Beam {
        /// Number of structures retained in the beam
        beam_width: usize,
    },
    /// Genetic algorithm - evolve population of kernel structures
    Genetic {
        /// Number of individuals created per generation
        population_size: usize,
    },
    /// Simulated annealing - accept worse solutions with decreasing probability
    SimulatedAnnealing {
        /// Temperature at the first iteration; it is multiplied by a fixed cooling rate afterwards
        initial_temperature: f64,
    },
}
51
52/// Kernel grammar production rules
53#[derive(Debug, Clone)]
54pub enum KernelGrammar {
55    /// Terminal symbols (base kernels)
56    Terminal(TerminalKernel),
57    /// Non-terminal symbols (operations)
58    NonTerminal(NonTerminalOperation),
59}
60
/// Base kernel types in the grammar
///
/// Each variant carries the parameters that are handed to the corresponding
/// kernel constructor when the expression is turned into a concrete kernel.
#[derive(Debug, Clone, PartialEq)]
pub enum TerminalKernel {
    /// Radial basis function kernel with the given length scale
    RBF { length_scale: f64 },
    /// Linear kernel parameterised by `sigma_0` and `sigma_1`
    Linear { sigma_0: f64, sigma_1: f64 },
    /// Periodic kernel (built as `ExpSineSquared`) with a length scale and a period
    Periodic { length_scale: f64, period: f64 },
    /// Matern kernel with a length scale and smoothness parameter `nu`
    Matern { length_scale: f64, nu: f64 },
    /// Rational quadratic kernel with a length scale and mixture parameter `alpha`
    RationalQuadratic { length_scale: f64, alpha: f64 },
    /// White-noise kernel with the given noise level
    White { noise_level: f64 },
    /// Constant kernel with the given value
    Constant { constant_value: f64 },
}
79
80/// Operations for combining kernels
81#[derive(Debug, Clone)]
82pub enum NonTerminalOperation {
83    /// Sum
84    Sum {
85        left: Box<KernelGrammar>,
86        right: Box<KernelGrammar>,
87    },
88    /// Product
89    Product {
90        left: Box<KernelGrammar>,
91        right: Box<KernelGrammar>,
92    },
93    /// Power
94    Power {
95        base: Box<KernelGrammar>,
96        exponent: f64,
97    },
98    /// Scale
99    Scale {
100        kernel: Box<KernelGrammar>,
101        scale: f64,
102    },
103}
104
105/// Result of kernel structure learning
106#[derive(Debug, Clone)]
107pub struct StructureLearningResult {
108    /// The best kernel structure found
109    pub best_kernel: Box<dyn Kernel>,
110    /// Grammar expression of the best kernel
111    pub best_expression: KernelGrammar,
112    /// Score of the best kernel
113    pub best_score: f64,
114    /// All structures explored with their scores
115    pub exploration_history: Vec<(String, f64)>,
116    /// Convergence information
117    pub convergence_info: ConvergenceInfo,
118}
119
/// Summary of how a structure search progressed
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Iteration count; the search routines fill this with the length of `score_history`
    pub iterations: usize,
    /// Score recorded after each step of the search
    pub score_history: Vec<f64>,
    /// Convergence flag reported by the search routine
    pub converged: bool,
    /// Temperature at the end of the run; `Some` only for simulated annealing
    pub final_temperature: Option<f64>,
}
132
133impl Default for KernelStructureLearner {
134    fn default() -> Self {
135        Self {
136            max_depth: 4,
137            max_iterations: 100,
138            expansion_probability: 0.7,
139            simplification_probability: 0.3,
140            improvement_threshold: 0.01,
141            use_bic: true,
142            random_state: Some(42),
143            search_strategy: SearchStrategy::Greedy,
144        }
145    }
146}
147
148impl KernelStructureLearner {
149    /// Create a new kernel structure learner
150    pub fn new() -> Self {
151        Self::default()
152    }
153
154    /// Set maximum depth of kernel expressions
155    pub fn max_depth(mut self, depth: usize) -> Self {
156        self.max_depth = depth;
157        self
158    }
159
160    /// Set maximum number of iterations
161    pub fn max_iterations(mut self, iterations: usize) -> Self {
162        self.max_iterations = iterations;
163        self
164    }
165
166    /// Set search strategy
167    pub fn search_strategy(mut self, strategy: SearchStrategy) -> Self {
168        self.search_strategy = strategy;
169        self
170    }
171
172    /// Set random state for reproducible results
173    pub fn random_state(mut self, seed: Option<u64>) -> Self {
174        self.random_state = seed;
175        self
176    }
177
178    /// Learn kernel structure from data
179    pub fn learn_structure(
180        &self,
181        X: ArrayView2<f64>,
182        y: ArrayView1<f64>,
183    ) -> SklResult<StructureLearningResult> {
184        // SciRS2 Policy - Use scirs2-core for random number generation
185        let mut rng = if let Some(seed) = self.random_state {
186            scirs2_core::random::Random::seed(seed)
187        } else {
188            scirs2_core::random::Random::seed(42)
189        };
190
191        match self.search_strategy {
192            SearchStrategy::Greedy => self.greedy_search(X, y, &mut rng),
193            SearchStrategy::Beam { beam_width } => self.beam_search(X, y, beam_width, &mut rng),
194            SearchStrategy::Genetic { population_size } => {
195                self.genetic_search(X, y, population_size, &mut rng)
196            }
197            SearchStrategy::SimulatedAnnealing {
198                initial_temperature,
199            } => self.simulated_annealing_search(X, y, initial_temperature, &mut rng),
200        }
201    }
202
203    /// Greedy search for kernel structure
204    fn greedy_search(
205        &self,
206        X: ArrayView2<f64>,
207        y: ArrayView1<f64>,
208        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
209    ) -> SklResult<StructureLearningResult> {
210        let mut exploration_history = Vec::new();
211        let mut score_history = Vec::new();
212
213        // Start with simple RBF kernel
214        let mut current_expression =
215            KernelGrammar::Terminal(TerminalKernel::RBF { length_scale: 1.0 });
216        let mut current_kernel = self.expression_to_kernel(&current_expression)?;
217        let mut current_score = self.evaluate_kernel(&current_kernel, &X, &y)?;
218
219        exploration_history.push(("RBF".to_string(), current_score));
220        score_history.push(current_score);
221
222        let mut best_expression = current_expression.clone();
223        let mut best_kernel = current_kernel.clone();
224        let mut best_score = current_score;
225
226        for _iteration in 0..self.max_iterations {
227            // Generate candidate structures
228            let candidates = self.generate_candidates(&current_expression, rng)?;
229
230            let mut found_improvement = false;
231
232            for candidate in candidates {
233                if self.expression_depth(&candidate) > self.max_depth {
234                    continue;
235                }
236
237                let kernel = self.expression_to_kernel(&candidate)?;
238                let score = self.evaluate_kernel(&kernel, &X, &y)?;
239
240                let expression_str = self.expression_to_string(&candidate);
241                exploration_history.push((expression_str, score));
242
243                if score < current_score - self.improvement_threshold {
244                    current_expression = candidate;
245                    current_kernel = kernel;
246                    current_score = score;
247                    found_improvement = true;
248
249                    if score < best_score {
250                        best_expression = current_expression.clone();
251                        best_kernel = current_kernel.clone();
252                        best_score = score;
253                    }
254                    break;
255                }
256            }
257
258            score_history.push(current_score);
259
260            if !found_improvement {
261                break;
262            }
263        }
264
265        Ok(StructureLearningResult {
266            best_kernel,
267            best_expression,
268            best_score,
269            exploration_history,
270            convergence_info: ConvergenceInfo {
271                iterations: score_history.len(),
272                score_history,
273                converged: true,
274                final_temperature: None,
275            },
276        })
277    }
278
279    /// Beam search for kernel structure
280    fn beam_search(
281        &self,
282        X: ArrayView2<f64>,
283        y: ArrayView1<f64>,
284        beam_width: usize,
285        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
286    ) -> SklResult<StructureLearningResult> {
287        let mut exploration_history = Vec::new();
288        let mut score_history = Vec::new();
289
290        // Initialize beam with base kernels
291        let mut beam: Vec<(KernelGrammar, f64)> = Vec::new();
292        let initial_kernels = self.generate_initial_kernels()?;
293
294        if initial_kernels.is_empty() {
295            return Err(SklearsError::InvalidOperation(
296                "No initial kernels generated".to_string(),
297            ));
298        }
299
300        for kernel_expr in initial_kernels {
301            let kernel = self.expression_to_kernel(&kernel_expr)?;
302            let score = self.evaluate_kernel(&kernel, &X, &y)?;
303            beam.push((kernel_expr.clone(), score));
304
305            let expr_str = self.expression_to_string(&kernel_expr);
306            exploration_history.push((expr_str, score));
307        }
308
309        // Sort beam by score (lower is better)
310        beam.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
311        beam.truncate(beam_width);
312
313        if beam.is_empty() {
314            return Err(SklearsError::InvalidOperation(
315                "Beam is empty after initialization".to_string(),
316            ));
317        }
318
319        let mut best_score = beam[0].1;
320        score_history.push(best_score);
321
322        for _iteration in 0..self.max_iterations {
323            let mut new_beam = Vec::new();
324
325            // Expand each beam element
326            for (expression, _) in &beam {
327                let candidates = self.generate_candidates(expression, rng)?;
328
329                for candidate in candidates {
330                    if self.expression_depth(&candidate) > self.max_depth {
331                        continue;
332                    }
333
334                    let kernel = self.expression_to_kernel(&candidate)?;
335                    let score = self.evaluate_kernel(&kernel, &X, &y)?;
336
337                    new_beam.push((candidate.clone(), score));
338
339                    let expr_str = self.expression_to_string(&candidate);
340                    exploration_history.push((expr_str, score));
341                }
342            }
343
344            // Merge and sort
345            new_beam.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
346            new_beam.truncate(beam_width);
347
348            // Update beam
349            beam = new_beam;
350
351            if beam.is_empty() {
352                break;
353            }
354
355            let current_best = beam[0].1;
356            score_history.push(current_best);
357
358            if (best_score - current_best).abs() < self.improvement_threshold {
359                break;
360            }
361
362            best_score = current_best;
363        }
364
365        if beam.is_empty() {
366            // Return a default RBF kernel if no beam elements remain
367            let default_expression =
368                KernelGrammar::Terminal(TerminalKernel::RBF { length_scale: 1.0 });
369            let default_kernel = self.expression_to_kernel(&default_expression)?;
370            return Ok(StructureLearningResult {
371                best_kernel: default_kernel,
372                best_expression: default_expression,
373                best_score: f64::INFINITY,
374                exploration_history,
375                convergence_info: ConvergenceInfo {
376                    iterations: score_history.len(),
377                    score_history,
378                    converged: false,
379                    final_temperature: None,
380                },
381            });
382        }
383
384        let best_expression = beam[0].0.clone();
385        let best_kernel = self.expression_to_kernel(&best_expression)?;
386
387        Ok(StructureLearningResult {
388            best_kernel,
389            best_expression,
390            best_score: beam[0].1,
391            exploration_history,
392            convergence_info: ConvergenceInfo {
393                iterations: score_history.len(),
394                score_history,
395                converged: true,
396                final_temperature: None,
397            },
398        })
399    }
400
401    /// Genetic algorithm search for kernel structure
402    fn genetic_search(
403        &self,
404        X: ArrayView2<f64>,
405        y: ArrayView1<f64>,
406        population_size: usize,
407        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
408    ) -> SklResult<StructureLearningResult> {
409        let mut exploration_history = Vec::new();
410        let mut score_history = Vec::new();
411
412        // Initialize population
413        let mut population: Vec<(KernelGrammar, f64)> = Vec::new();
414        let initial_kernels = self.generate_initial_kernels()?;
415
416        for _ in 0..population_size {
417            let idx = rng.gen_range(0..initial_kernels.len());
418            let kernel_expr = initial_kernels[idx].clone();
419            let kernel = self.expression_to_kernel(&kernel_expr)?;
420            let score = self.evaluate_kernel(&kernel, &X, &y)?;
421            population.push((kernel_expr, score));
422        }
423
424        for _generation in 0..self.max_iterations {
425            // Selection: tournament selection
426            let mut new_population = Vec::new();
427
428            for _ in 0..population_size {
429                let parent1 = self.tournament_selection(&population, rng);
430                let parent2 = self.tournament_selection(&population, rng);
431
432                // Crossover
433                let child = self.crossover(&parent1.0, &parent2.0, rng)?;
434
435                // Mutation
436                let mutated_child = self.mutate(&child, rng)?;
437
438                if self.expression_depth(&mutated_child) <= self.max_depth {
439                    let kernel = self.expression_to_kernel(&mutated_child)?;
440                    let score = self.evaluate_kernel(&kernel, &X, &y)?;
441                    new_population.push((mutated_child.clone(), score));
442
443                    let expr_str = self.expression_to_string(&mutated_child);
444                    exploration_history.push((expr_str, score));
445                }
446            }
447
448            // Replace population
449            population = new_population;
450            population.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
451            population.truncate(population_size);
452
453            let best_score = population[0].1;
454            score_history.push(best_score);
455        }
456
457        let best_expression = population[0].0.clone();
458        let best_kernel = self.expression_to_kernel(&best_expression)?;
459
460        Ok(StructureLearningResult {
461            best_kernel,
462            best_expression,
463            best_score: population[0].1,
464            exploration_history,
465            convergence_info: ConvergenceInfo {
466                iterations: score_history.len(),
467                score_history,
468                converged: true,
469                final_temperature: None,
470            },
471        })
472    }
473
474    /// Simulated annealing search for kernel structure
475    fn simulated_annealing_search(
476        &self,
477        X: ArrayView2<f64>,
478        y: ArrayView1<f64>,
479        initial_temperature: f64,
480        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
481    ) -> SklResult<StructureLearningResult> {
482        let mut exploration_history = Vec::new();
483        let mut score_history = Vec::new();
484
485        // Start with simple RBF kernel
486        let mut current_expression =
487            KernelGrammar::Terminal(TerminalKernel::RBF { length_scale: 1.0 });
488        let mut current_kernel = self.expression_to_kernel(&current_expression)?;
489        let mut current_score = self.evaluate_kernel(&current_kernel, &X, &y)?;
490
491        let mut best_expression = current_expression.clone();
492        let mut best_kernel = current_kernel.clone();
493        let mut best_score = current_score;
494
495        let mut temperature = initial_temperature;
496        let cooling_rate = 0.95;
497
498        for _iteration in 0..self.max_iterations {
499            // Generate a neighbor
500            let candidates = self.generate_candidates(&current_expression, rng)?;
501            if candidates.is_empty() {
502                break;
503            }
504
505            let idx = rng.gen_range(0..candidates.len());
506            let candidate = &candidates[idx];
507            if self.expression_depth(candidate) > self.max_depth {
508                continue;
509            }
510
511            let kernel = self.expression_to_kernel(candidate)?;
512            let score = self.evaluate_kernel(&kernel, &X, &y)?;
513
514            let expr_str = self.expression_to_string(candidate);
515            exploration_history.push((expr_str, score));
516
517            // Accept or reject based on Metropolis criterion
518            let delta = score - current_score;
519            if delta < 0.0 || rng.gen::<f64>() < (-delta / temperature).exp() {
520                current_expression = candidate.clone();
521                current_kernel = kernel;
522                current_score = score;
523
524                if score < best_score {
525                    best_expression = current_expression.clone();
526                    best_kernel = current_kernel.clone();
527                    best_score = score;
528                }
529            }
530
531            score_history.push(current_score);
532            temperature *= cooling_rate;
533        }
534
535        Ok(StructureLearningResult {
536            best_kernel,
537            best_expression,
538            best_score,
539            exploration_history,
540            convergence_info: ConvergenceInfo {
541                iterations: score_history.len(),
542                score_history,
543                converged: true,
544                final_temperature: Some(temperature),
545            },
546        })
547    }
548
549    /// Generate initial base kernels
550    fn generate_initial_kernels(&self) -> SklResult<Vec<KernelGrammar>> {
551        Ok(vec![
552            KernelGrammar::Terminal(TerminalKernel::RBF { length_scale: 1.0 }),
553            KernelGrammar::Terminal(TerminalKernel::Linear {
554                sigma_0: 1.0,
555                sigma_1: 1.0,
556            }),
557            KernelGrammar::Terminal(TerminalKernel::Matern {
558                length_scale: 1.0,
559                nu: 1.5,
560            }),
561            KernelGrammar::Terminal(TerminalKernel::RationalQuadratic {
562                length_scale: 1.0,
563                alpha: 1.0,
564            }),
565            KernelGrammar::Terminal(TerminalKernel::Periodic {
566                length_scale: 1.0,
567                period: 1.0,
568            }),
569            KernelGrammar::Terminal(TerminalKernel::White { noise_level: 0.1 }),
570            KernelGrammar::Terminal(TerminalKernel::Constant {
571                constant_value: 1.0,
572            }),
573        ])
574    }
575
576    /// Generate candidate structures from current expression
577    fn generate_candidates(
578        &self,
579        expression: &KernelGrammar,
580        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
581    ) -> SklResult<Vec<KernelGrammar>> {
582        let mut candidates = Vec::new();
583
584        // Add a new component with sum
585        if rng.gen::<f64>() < self.expansion_probability {
586            let new_kernels = self.generate_initial_kernels()?;
587            for new_kernel in new_kernels {
588                candidates.push(KernelGrammar::NonTerminal(NonTerminalOperation::Sum {
589                    left: Box::new(expression.clone()),
590                    right: Box::new(new_kernel),
591                }));
592            }
593        }
594
595        // Add a new component with product
596        if rng.gen::<f64>() < self.expansion_probability {
597            let new_kernels = self.generate_initial_kernels()?;
598            for new_kernel in new_kernels {
599                candidates.push(KernelGrammar::NonTerminal(NonTerminalOperation::Product {
600                    left: Box::new(expression.clone()),
601                    right: Box::new(new_kernel),
602                }));
603            }
604        }
605
606        // Scale the current kernel
607        if rng.gen::<f64>() < 0.3 {
608            let scale = rng.gen_range(0.1..10.0);
609            candidates.push(KernelGrammar::NonTerminal(NonTerminalOperation::Scale {
610                kernel: Box::new(expression.clone()),
611                scale,
612            }));
613        }
614
615        Ok(candidates)
616    }
617
618    /// Tournament selection for genetic algorithm
619    fn tournament_selection<'a>(
620        &self,
621        population: &'a [(KernelGrammar, f64)],
622        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
623    ) -> &'a (KernelGrammar, f64) {
624        let tournament_size = 3.min(population.len());
625        let mut best_idx = rng.gen_range(0..population.len());
626        let mut best_score = population[best_idx].1;
627
628        for _ in 1..tournament_size {
629            let idx = rng.gen_range(0..population.len());
630            if population[idx].1 < best_score {
631                best_idx = idx;
632                best_score = population[idx].1;
633            }
634        }
635
636        &population[best_idx]
637    }
638
639    /// Crossover operation for genetic algorithm
640    fn crossover(
641        &self,
642        parent1: &KernelGrammar,
643        parent2: &KernelGrammar,
644        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
645    ) -> SklResult<KernelGrammar> {
646        if rng.gen::<f64>() < 0.5 {
647            Ok(KernelGrammar::NonTerminal(NonTerminalOperation::Sum {
648                left: Box::new(parent1.clone()),
649                right: Box::new(parent2.clone()),
650            }))
651        } else {
652            Ok(KernelGrammar::NonTerminal(NonTerminalOperation::Product {
653                left: Box::new(parent1.clone()),
654                right: Box::new(parent2.clone()),
655            }))
656        }
657    }
658
659    /// Mutation operation for genetic algorithm
660    fn mutate(
661        &self,
662        expression: &KernelGrammar,
663        rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
664    ) -> SklResult<KernelGrammar> {
665        if rng.gen::<f64>() < 0.1 {
666            // Replace with random kernel
667            let new_kernels = self.generate_initial_kernels()?;
668            let idx = rng.gen_range(0..new_kernels.len());
669            Ok(new_kernels[idx].clone())
670        } else {
671            // Keep original
672            Ok(expression.clone())
673        }
674    }
675
676    /// Calculate depth of kernel expression
677    fn expression_depth(&self, expression: &KernelGrammar) -> usize {
678        match expression {
679            KernelGrammar::Terminal(_) => 1,
680            KernelGrammar::NonTerminal(op) => match op {
681                NonTerminalOperation::Sum { left, right }
682                | NonTerminalOperation::Product { left, right } => {
683                    1 + self
684                        .expression_depth(left)
685                        .max(self.expression_depth(right))
686                }
687                NonTerminalOperation::Scale { kernel, .. }
688                | NonTerminalOperation::Power { base: kernel, .. } => {
689                    1 + self.expression_depth(kernel)
690                }
691            },
692        }
693    }
694
695    /// Convert grammar expression to actual kernel
696    fn expression_to_kernel(&self, expression: &KernelGrammar) -> SklResult<Box<dyn Kernel>> {
697        match expression {
698            KernelGrammar::Terminal(terminal) => match terminal {
699                TerminalKernel::RBF { length_scale } => Ok(Box::new(RBF::new(*length_scale))),
700                TerminalKernel::Linear { sigma_0, sigma_1 } => {
701                    Ok(Box::new(Linear::new(*sigma_0, *sigma_1)))
702                }
703                TerminalKernel::Periodic {
704                    length_scale,
705                    period,
706                } => Ok(Box::new(ExpSineSquared::new(*length_scale, *period))),
707                TerminalKernel::Matern { length_scale, nu } => {
708                    Ok(Box::new(Matern::new(*length_scale, *nu)))
709                }
710                TerminalKernel::RationalQuadratic {
711                    length_scale,
712                    alpha,
713                } => Ok(Box::new(RationalQuadratic::new(*length_scale, *alpha))),
714                TerminalKernel::White { noise_level } => {
715                    Ok(Box::new(WhiteKernel::new(*noise_level)))
716                }
717                TerminalKernel::Constant { constant_value } => {
718                    Ok(Box::new(ConstantKernel::new(*constant_value)))
719                }
720            },
721            KernelGrammar::NonTerminal(op) => match op {
722                NonTerminalOperation::Sum { left, right } => {
723                    let left_kernel = self.expression_to_kernel(left)?;
724                    let right_kernel = self.expression_to_kernel(right)?;
725                    Ok(Box::new(crate::kernels::SumKernel::new(vec![
726                        left_kernel,
727                        right_kernel,
728                    ])))
729                }
730                NonTerminalOperation::Product { left, right } => {
731                    let left_kernel = self.expression_to_kernel(left)?;
732                    let right_kernel = self.expression_to_kernel(right)?;
733                    Ok(Box::new(crate::kernels::ProductKernel::new(vec![
734                        left_kernel,
735                        right_kernel,
736                    ])))
737                }
738                NonTerminalOperation::Scale { kernel, scale } => {
739                    let base_kernel = self.expression_to_kernel(kernel)?;
740                    // Scale by creating a constant kernel and using product
741                    let scale_kernel = Box::new(ConstantKernel::new(*scale));
742                    Ok(Box::new(crate::kernels::ProductKernel::new(vec![
743                        base_kernel,
744                        scale_kernel,
745                    ])))
746                }
747                NonTerminalOperation::Power { base, exponent: _ } => {
748                    // For now, just return the base kernel
749                    // Power kernels would need special implementation
750                    self.expression_to_kernel(base)
751                }
752            },
753        }
754    }
755
756    /// Convert expression to human-readable string
757    fn expression_to_string(&self, expression: &KernelGrammar) -> String {
758        match expression {
759            KernelGrammar::Terminal(terminal) => match terminal {
760                TerminalKernel::RBF { length_scale } => format!("RBF({:.3})", length_scale),
761                TerminalKernel::Linear { sigma_0, sigma_1 } => {
762                    format!("Linear({:.3}, {:.3})", sigma_0, sigma_1)
763                }
764                TerminalKernel::Periodic {
765                    length_scale,
766                    period,
767                } => format!("Periodic({:.3}, {:.3})", length_scale, period),
768                TerminalKernel::Matern { length_scale, nu } => {
769                    format!("Matern({:.3}, {:.3})", length_scale, nu)
770                }
771                TerminalKernel::RationalQuadratic {
772                    length_scale,
773                    alpha,
774                } => format!("RQ({:.3}, {:.3})", length_scale, alpha),
775                TerminalKernel::White { noise_level } => format!("White({:.3})", noise_level),
776                TerminalKernel::Constant { constant_value } => {
777                    format!("Const({:.3})", constant_value)
778                }
779            },
780            KernelGrammar::NonTerminal(op) => match op {
781                NonTerminalOperation::Sum { left, right } => {
782                    format!(
783                        "({} + {})",
784                        self.expression_to_string(left),
785                        self.expression_to_string(right)
786                    )
787                }
788                NonTerminalOperation::Product { left, right } => {
789                    format!(
790                        "({} * {})",
791                        self.expression_to_string(left),
792                        self.expression_to_string(right)
793                    )
794                }
795                NonTerminalOperation::Scale { kernel, scale } => {
796                    format!("{:.3} * {}", scale, self.expression_to_string(kernel))
797                }
798                NonTerminalOperation::Power { base, exponent } => {
799                    format!("{}^{:.3}", self.expression_to_string(base), exponent)
800                }
801            },
802        }
803    }
804
805    /// Evaluate kernel using BIC or cross-validation
806    fn evaluate_kernel(
807        &self,
808        kernel: &Box<dyn Kernel>,
809        X: &ArrayView2<f64>,
810        y: &ArrayView1<f64>,
811    ) -> SklResult<f64> {
812        // For now, use simple marginal likelihood
813        // This can be extended to use cross-validation or BIC
814        self.compute_marginal_likelihood(kernel, X, y)
815    }
816
817    /// Compute marginal likelihood for kernel evaluation
818    #[allow(non_snake_case)]
819    fn compute_marginal_likelihood(
820        &self,
821        kernel: &Box<dyn Kernel>,
822        X: &ArrayView2<f64>,
823        y: &ArrayView1<f64>,
824    ) -> SklResult<f64> {
825        let X_owned = X.to_owned();
826        let K = kernel.compute_kernel_matrix(&X_owned, Some(&X_owned))?;
827
828        // Add noise to diagonal
829        let mut K_noisy = K;
830        let noise_var = 0.1;
831        for i in 0..K_noisy.nrows() {
832            K_noisy[[i, i]] += noise_var;
833        }
834
835        // Compute Cholesky decomposition
836        match crate::utils::cholesky_decomposition(&K_noisy) {
837            Ok(L) => {
838                // Compute log marginal likelihood
839                let mut log_det = 0.0;
840                for i in 0..L.nrows() {
841                    log_det += L[[i, i]].ln();
842                }
843                log_det *= 2.0;
844
845                // Solve for alpha = K^(-1) * y
846                let y_owned = y.to_owned();
847                let alpha = match crate::utils::triangular_solve(&L, &y_owned) {
848                    Ok(temp) => {
849                        let L_T = L.t();
850                        crate::utils::triangular_solve(&L_T.view().to_owned(), &temp)?
851                    }
852                    Err(_) => return Ok(f64::INFINITY),
853                };
854
855                let data_fit = -0.5 * y.dot(&alpha);
856                let complexity_penalty = -0.5 * log_det;
857                let normalization = -0.5 * y.len() as f64 * (2.0 * std::f64::consts::PI).ln();
858
859                Ok(-(data_fit + complexity_penalty + normalization))
860            }
861            Err(_) => Ok(f64::INFINITY),
862        }
863    }
864}
865
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    // SciRS2 Policy - Use scirs2-autograd for ndarray types and operations
    use scirs2_core::ndarray::{Array1, Array2};

    /// A freshly built learner reports the expected default search limits.
    #[test]
    fn test_kernel_structure_learner_creation() {
        let defaults = KernelStructureLearner::new();
        assert_eq!(defaults.max_depth, 4);
        assert_eq!(defaults.max_iterations, 100);
    }

    /// Depth is 1 for a lone terminal and 2 for a sum of two terminals.
    #[test]
    fn test_expression_depth_calculation() {
        let learner = KernelStructureLearner::new();
        let leaf = || KernelGrammar::Terminal(TerminalKernel::RBF { length_scale: 1.0 });

        assert_eq!(learner.expression_depth(&leaf()), 1);

        let two_leaf_sum = KernelGrammar::NonTerminal(NonTerminalOperation::Sum {
            left: Box::new(leaf()),
            right: Box::new(leaf()),
        });
        assert_eq!(learner.expression_depth(&two_leaf_sum), 2);
    }

    /// A terminal RBF expression converts into a kernel that yields a 3x3
    /// matrix for three samples.
    #[test]
    fn test_expression_to_kernel_conversion() {
        let learner = KernelStructureLearner::new();
        let rbf_expr = KernelGrammar::Terminal(TerminalKernel::RBF { length_scale: 2.0 });
        let kernel = learner.expression_to_kernel(&rbf_expr).unwrap();

        // Three samples with two features each.
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let X = Array2::from_shape_vec((3, 2), samples).unwrap();

        let gram = kernel.compute_kernel_matrix(&X, Some(&X)).unwrap();
        assert_eq!(gram.dim(), (3, 3));
    }

    /// The default search on a tiny quadratic data set produces a finite
    /// score and records its history.
    #[test]
    fn test_greedy_search() {
        let learner = KernelStructureLearner::new()
            .max_iterations(5)
            .max_depth(2);

        let inputs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let targets = vec![1.0, 4.0, 9.0, 16.0, 25.0];
        let X = Array2::from_shape_vec((5, 1), inputs).unwrap();
        let y = Array1::from_vec(targets);

        let outcome = learner.learn_structure(X.view(), y.view()).unwrap();

        assert!(outcome.best_score.is_finite());
        assert!(!outcome.exploration_history.is_empty());
        assert!(!outcome.convergence_info.score_history.is_empty());
    }

    /// The rendered string of `RBF + Linear` mentions both operands and the
    /// operator.
    #[test]
    fn test_expression_to_string() {
        let learner = KernelStructureLearner::new();

        let rbf_plus_linear = KernelGrammar::NonTerminal(NonTerminalOperation::Sum {
            left: Box::new(KernelGrammar::Terminal(TerminalKernel::RBF {
                length_scale: 1.0,
            })),
            right: Box::new(KernelGrammar::Terminal(TerminalKernel::Linear {
                sigma_0: 1.0,
                sigma_1: 1.0,
            })),
        });

        let rendered = learner.expression_to_string(&rbf_plus_linear);
        for fragment in ["RBF", "Linear", "+"].iter() {
            assert!(rendered.contains(fragment));
        }
    }

    /// Beam search on a tiny linear data set completes and records the
    /// structures it explored.
    #[test]
    fn test_beam_search() {
        let learner = KernelStructureLearner::new()
            .max_iterations(3)
            .max_depth(2)
            .search_strategy(SearchStrategy::Beam { beam_width: 2 });

        let X = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let outcome = learner.learn_structure(X.view(), y.view()).unwrap();

        // The search might return infinity if no valid kernels are found
        let score = outcome.best_score;
        assert!(score.is_finite() || score == f64::INFINITY);
        assert!(!outcome.exploration_history.is_empty());
    }
}