1use crate::kernels::*;
8use scirs2_core::ndarray::{ArrayView1, ArrayView2};
10use scirs2_core::random::Rng;
12use sklears_core::error::{Result as SklResult, SklearsError};
13
14#[derive(Debug, Clone)]
16pub struct KernelStructureLearner {
17 pub max_depth: usize,
19 pub max_iterations: usize,
21 pub expansion_probability: f64,
23 pub simplification_probability: f64,
25 pub improvement_threshold: f64,
27 pub use_bic: bool,
29 pub random_state: Option<u64>,
31 pub search_strategy: SearchStrategy,
33}
34
/// Algorithm used to explore the space of kernel expressions.
///
/// `PartialEq` is derived so configured strategies can be compared directly
/// (only partial equality is possible because of the `f64` temperature).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SearchStrategy {
    /// Hill climbing from a single RBF kernel, taking the first improving candidate.
    Greedy,
    /// Beam search that keeps the best `beam_width` expressions per round.
    Beam {
        /// Number of expressions kept after each round.
        beam_width: usize,
    },
    /// Generational search using tournament selection, crossover and mutation.
    Genetic {
        /// Number of individuals bred per generation.
        population_size: usize,
    },
    /// Simulated annealing with a geometric cooling schedule.
    SimulatedAnnealing {
        /// Temperature at the first iteration.
        initial_temperature: f64,
    },
}
51
52#[derive(Debug, Clone)]
54pub enum KernelGrammar {
55 Terminal(TerminalKernel),
57 NonTerminal(NonTerminalOperation),
59}
60
/// Base kernels that may appear as leaves of a [`KernelGrammar`] expression.
///
/// Each variant carries the hyperparameters handed to the matching kernel
/// constructor. `PartialEq` is derived so leaves can be compared in tests and
/// de-duplication code.
#[derive(Debug, Clone, PartialEq)]
pub enum TerminalKernel {
    /// Radial basis function kernel.
    RBF { length_scale: f64 },
    /// Linear kernel.
    Linear { sigma_0: f64, sigma_1: f64 },
    /// Periodic kernel (built as an `ExpSineSquared` kernel).
    Periodic { length_scale: f64, period: f64 },
    /// Matern kernel with smoothness parameter `nu`.
    Matern { length_scale: f64, nu: f64 },
    /// Rational quadratic kernel.
    RationalQuadratic { length_scale: f64, alpha: f64 },
    /// White-noise kernel.
    White { noise_level: f64 },
    /// Constant kernel.
    Constant { constant_value: f64 },
}
79
80#[derive(Debug, Clone)]
82pub enum NonTerminalOperation {
83 Sum {
85 left: Box<KernelGrammar>,
86 right: Box<KernelGrammar>,
87 },
88 Product {
90 left: Box<KernelGrammar>,
91 right: Box<KernelGrammar>,
92 },
93 Power {
95 base: Box<KernelGrammar>,
96 exponent: f64,
97 },
98 Scale {
100 kernel: Box<KernelGrammar>,
101 scale: f64,
102 },
103}
104
105#[derive(Debug, Clone)]
107pub struct StructureLearningResult {
108 pub best_kernel: Box<dyn Kernel>,
110 pub best_expression: KernelGrammar,
112 pub best_score: f64,
114 pub exploration_history: Vec<(String, f64)>,
116 pub convergence_info: ConvergenceInfo,
118}
119
/// Bookkeeping about how a structure search progressed.
///
/// `PartialEq` is derived so results can be compared directly in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvergenceInfo {
    /// Number of recorded entries; the search routines set this to
    /// `score_history.len()`.
    pub iterations: usize,
    /// Score recorded after each search step.
    pub score_history: Vec<f64>,
    /// Whether the search reported convergence.
    pub converged: bool,
    /// Temperature at the end of the run; `Some` only for simulated annealing.
    pub final_temperature: Option<f64>,
}
132
133impl Default for KernelStructureLearner {
134 fn default() -> Self {
135 Self {
136 max_depth: 4,
137 max_iterations: 100,
138 expansion_probability: 0.7,
139 simplification_probability: 0.3,
140 improvement_threshold: 0.01,
141 use_bic: true,
142 random_state: Some(42),
143 search_strategy: SearchStrategy::Greedy,
144 }
145 }
146}
147
148impl KernelStructureLearner {
149 pub fn new() -> Self {
151 Self::default()
152 }
153
154 pub fn max_depth(mut self, depth: usize) -> Self {
156 self.max_depth = depth;
157 self
158 }
159
160 pub fn max_iterations(mut self, iterations: usize) -> Self {
162 self.max_iterations = iterations;
163 self
164 }
165
166 pub fn search_strategy(mut self, strategy: SearchStrategy) -> Self {
168 self.search_strategy = strategy;
169 self
170 }
171
172 pub fn random_state(mut self, seed: Option<u64>) -> Self {
174 self.random_state = seed;
175 self
176 }
177
178 pub fn learn_structure(
180 &self,
181 X: ArrayView2<f64>,
182 y: ArrayView1<f64>,
183 ) -> SklResult<StructureLearningResult> {
184 let mut rng = if let Some(seed) = self.random_state {
186 scirs2_core::random::Random::seed(seed)
187 } else {
188 scirs2_core::random::Random::seed(42)
189 };
190
191 match self.search_strategy {
192 SearchStrategy::Greedy => self.greedy_search(X, y, &mut rng),
193 SearchStrategy::Beam { beam_width } => self.beam_search(X, y, beam_width, &mut rng),
194 SearchStrategy::Genetic { population_size } => {
195 self.genetic_search(X, y, population_size, &mut rng)
196 }
197 SearchStrategy::SimulatedAnnealing {
198 initial_temperature,
199 } => self.simulated_annealing_search(X, y, initial_temperature, &mut rng),
200 }
201 }
202
203 fn greedy_search(
205 &self,
206 X: ArrayView2<f64>,
207 y: ArrayView1<f64>,
208 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
209 ) -> SklResult<StructureLearningResult> {
210 let mut exploration_history = Vec::new();
211 let mut score_history = Vec::new();
212
213 let mut current_expression =
215 KernelGrammar::Terminal(TerminalKernel::RBF { length_scale: 1.0 });
216 let mut current_kernel = self.expression_to_kernel(¤t_expression)?;
217 let mut current_score = self.evaluate_kernel(¤t_kernel, &X, &y)?;
218
219 exploration_history.push(("RBF".to_string(), current_score));
220 score_history.push(current_score);
221
222 let mut best_expression = current_expression.clone();
223 let mut best_kernel = current_kernel.clone();
224 let mut best_score = current_score;
225
226 for _iteration in 0..self.max_iterations {
227 let candidates = self.generate_candidates(¤t_expression, rng)?;
229
230 let mut found_improvement = false;
231
232 for candidate in candidates {
233 if self.expression_depth(&candidate) > self.max_depth {
234 continue;
235 }
236
237 let kernel = self.expression_to_kernel(&candidate)?;
238 let score = self.evaluate_kernel(&kernel, &X, &y)?;
239
240 let expression_str = self.expression_to_string(&candidate);
241 exploration_history.push((expression_str, score));
242
243 if score < current_score - self.improvement_threshold {
244 current_expression = candidate;
245 current_kernel = kernel;
246 current_score = score;
247 found_improvement = true;
248
249 if score < best_score {
250 best_expression = current_expression.clone();
251 best_kernel = current_kernel.clone();
252 best_score = score;
253 }
254 break;
255 }
256 }
257
258 score_history.push(current_score);
259
260 if !found_improvement {
261 break;
262 }
263 }
264
265 Ok(StructureLearningResult {
266 best_kernel,
267 best_expression,
268 best_score,
269 exploration_history,
270 convergence_info: ConvergenceInfo {
271 iterations: score_history.len(),
272 score_history,
273 converged: true,
274 final_temperature: None,
275 },
276 })
277 }
278
279 fn beam_search(
281 &self,
282 X: ArrayView2<f64>,
283 y: ArrayView1<f64>,
284 beam_width: usize,
285 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
286 ) -> SklResult<StructureLearningResult> {
287 let mut exploration_history = Vec::new();
288 let mut score_history = Vec::new();
289
290 let mut beam: Vec<(KernelGrammar, f64)> = Vec::new();
292 let initial_kernels = self.generate_initial_kernels()?;
293
294 if initial_kernels.is_empty() {
295 return Err(SklearsError::InvalidOperation(
296 "No initial kernels generated".to_string(),
297 ));
298 }
299
300 for kernel_expr in initial_kernels {
301 let kernel = self.expression_to_kernel(&kernel_expr)?;
302 let score = self.evaluate_kernel(&kernel, &X, &y)?;
303 beam.push((kernel_expr.clone(), score));
304
305 let expr_str = self.expression_to_string(&kernel_expr);
306 exploration_history.push((expr_str, score));
307 }
308
309 beam.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
311 beam.truncate(beam_width);
312
313 if beam.is_empty() {
314 return Err(SklearsError::InvalidOperation(
315 "Beam is empty after initialization".to_string(),
316 ));
317 }
318
319 let mut best_score = beam[0].1;
320 score_history.push(best_score);
321
322 for _iteration in 0..self.max_iterations {
323 let mut new_beam = Vec::new();
324
325 for (expression, _) in &beam {
327 let candidates = self.generate_candidates(expression, rng)?;
328
329 for candidate in candidates {
330 if self.expression_depth(&candidate) > self.max_depth {
331 continue;
332 }
333
334 let kernel = self.expression_to_kernel(&candidate)?;
335 let score = self.evaluate_kernel(&kernel, &X, &y)?;
336
337 new_beam.push((candidate.clone(), score));
338
339 let expr_str = self.expression_to_string(&candidate);
340 exploration_history.push((expr_str, score));
341 }
342 }
343
344 new_beam.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
346 new_beam.truncate(beam_width);
347
348 beam = new_beam;
350
351 if beam.is_empty() {
352 break;
353 }
354
355 let current_best = beam[0].1;
356 score_history.push(current_best);
357
358 if (best_score - current_best).abs() < self.improvement_threshold {
359 break;
360 }
361
362 best_score = current_best;
363 }
364
365 if beam.is_empty() {
366 let default_expression =
368 KernelGrammar::Terminal(TerminalKernel::RBF { length_scale: 1.0 });
369 let default_kernel = self.expression_to_kernel(&default_expression)?;
370 return Ok(StructureLearningResult {
371 best_kernel: default_kernel,
372 best_expression: default_expression,
373 best_score: f64::INFINITY,
374 exploration_history,
375 convergence_info: ConvergenceInfo {
376 iterations: score_history.len(),
377 score_history,
378 converged: false,
379 final_temperature: None,
380 },
381 });
382 }
383
384 let best_expression = beam[0].0.clone();
385 let best_kernel = self.expression_to_kernel(&best_expression)?;
386
387 Ok(StructureLearningResult {
388 best_kernel,
389 best_expression,
390 best_score: beam[0].1,
391 exploration_history,
392 convergence_info: ConvergenceInfo {
393 iterations: score_history.len(),
394 score_history,
395 converged: true,
396 final_temperature: None,
397 },
398 })
399 }
400
401 fn genetic_search(
403 &self,
404 X: ArrayView2<f64>,
405 y: ArrayView1<f64>,
406 population_size: usize,
407 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
408 ) -> SklResult<StructureLearningResult> {
409 let mut exploration_history = Vec::new();
410 let mut score_history = Vec::new();
411
412 let mut population: Vec<(KernelGrammar, f64)> = Vec::new();
414 let initial_kernels = self.generate_initial_kernels()?;
415
416 for _ in 0..population_size {
417 let idx = rng.gen_range(0..initial_kernels.len());
418 let kernel_expr = initial_kernels[idx].clone();
419 let kernel = self.expression_to_kernel(&kernel_expr)?;
420 let score = self.evaluate_kernel(&kernel, &X, &y)?;
421 population.push((kernel_expr, score));
422 }
423
424 for _generation in 0..self.max_iterations {
425 let mut new_population = Vec::new();
427
428 for _ in 0..population_size {
429 let parent1 = self.tournament_selection(&population, rng);
430 let parent2 = self.tournament_selection(&population, rng);
431
432 let child = self.crossover(&parent1.0, &parent2.0, rng)?;
434
435 let mutated_child = self.mutate(&child, rng)?;
437
438 if self.expression_depth(&mutated_child) <= self.max_depth {
439 let kernel = self.expression_to_kernel(&mutated_child)?;
440 let score = self.evaluate_kernel(&kernel, &X, &y)?;
441 new_population.push((mutated_child.clone(), score));
442
443 let expr_str = self.expression_to_string(&mutated_child);
444 exploration_history.push((expr_str, score));
445 }
446 }
447
448 population = new_population;
450 population.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
451 population.truncate(population_size);
452
453 let best_score = population[0].1;
454 score_history.push(best_score);
455 }
456
457 let best_expression = population[0].0.clone();
458 let best_kernel = self.expression_to_kernel(&best_expression)?;
459
460 Ok(StructureLearningResult {
461 best_kernel,
462 best_expression,
463 best_score: population[0].1,
464 exploration_history,
465 convergence_info: ConvergenceInfo {
466 iterations: score_history.len(),
467 score_history,
468 converged: true,
469 final_temperature: None,
470 },
471 })
472 }
473
474 fn simulated_annealing_search(
476 &self,
477 X: ArrayView2<f64>,
478 y: ArrayView1<f64>,
479 initial_temperature: f64,
480 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
481 ) -> SklResult<StructureLearningResult> {
482 let mut exploration_history = Vec::new();
483 let mut score_history = Vec::new();
484
485 let mut current_expression =
487 KernelGrammar::Terminal(TerminalKernel::RBF { length_scale: 1.0 });
488 let mut current_kernel = self.expression_to_kernel(¤t_expression)?;
489 let mut current_score = self.evaluate_kernel(¤t_kernel, &X, &y)?;
490
491 let mut best_expression = current_expression.clone();
492 let mut best_kernel = current_kernel.clone();
493 let mut best_score = current_score;
494
495 let mut temperature = initial_temperature;
496 let cooling_rate = 0.95;
497
498 for _iteration in 0..self.max_iterations {
499 let candidates = self.generate_candidates(¤t_expression, rng)?;
501 if candidates.is_empty() {
502 break;
503 }
504
505 let idx = rng.gen_range(0..candidates.len());
506 let candidate = &candidates[idx];
507 if self.expression_depth(candidate) > self.max_depth {
508 continue;
509 }
510
511 let kernel = self.expression_to_kernel(candidate)?;
512 let score = self.evaluate_kernel(&kernel, &X, &y)?;
513
514 let expr_str = self.expression_to_string(candidate);
515 exploration_history.push((expr_str, score));
516
517 let delta = score - current_score;
519 if delta < 0.0 || rng.gen::<f64>() < (-delta / temperature).exp() {
520 current_expression = candidate.clone();
521 current_kernel = kernel;
522 current_score = score;
523
524 if score < best_score {
525 best_expression = current_expression.clone();
526 best_kernel = current_kernel.clone();
527 best_score = score;
528 }
529 }
530
531 score_history.push(current_score);
532 temperature *= cooling_rate;
533 }
534
535 Ok(StructureLearningResult {
536 best_kernel,
537 best_expression,
538 best_score,
539 exploration_history,
540 convergence_info: ConvergenceInfo {
541 iterations: score_history.len(),
542 score_history,
543 converged: true,
544 final_temperature: Some(temperature),
545 },
546 })
547 }
548
549 fn generate_initial_kernels(&self) -> SklResult<Vec<KernelGrammar>> {
551 Ok(vec![
552 KernelGrammar::Terminal(TerminalKernel::RBF { length_scale: 1.0 }),
553 KernelGrammar::Terminal(TerminalKernel::Linear {
554 sigma_0: 1.0,
555 sigma_1: 1.0,
556 }),
557 KernelGrammar::Terminal(TerminalKernel::Matern {
558 length_scale: 1.0,
559 nu: 1.5,
560 }),
561 KernelGrammar::Terminal(TerminalKernel::RationalQuadratic {
562 length_scale: 1.0,
563 alpha: 1.0,
564 }),
565 KernelGrammar::Terminal(TerminalKernel::Periodic {
566 length_scale: 1.0,
567 period: 1.0,
568 }),
569 KernelGrammar::Terminal(TerminalKernel::White { noise_level: 0.1 }),
570 KernelGrammar::Terminal(TerminalKernel::Constant {
571 constant_value: 1.0,
572 }),
573 ])
574 }
575
576 fn generate_candidates(
578 &self,
579 expression: &KernelGrammar,
580 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
581 ) -> SklResult<Vec<KernelGrammar>> {
582 let mut candidates = Vec::new();
583
584 if rng.gen::<f64>() < self.expansion_probability {
586 let new_kernels = self.generate_initial_kernels()?;
587 for new_kernel in new_kernels {
588 candidates.push(KernelGrammar::NonTerminal(NonTerminalOperation::Sum {
589 left: Box::new(expression.clone()),
590 right: Box::new(new_kernel),
591 }));
592 }
593 }
594
595 if rng.gen::<f64>() < self.expansion_probability {
597 let new_kernels = self.generate_initial_kernels()?;
598 for new_kernel in new_kernels {
599 candidates.push(KernelGrammar::NonTerminal(NonTerminalOperation::Product {
600 left: Box::new(expression.clone()),
601 right: Box::new(new_kernel),
602 }));
603 }
604 }
605
606 if rng.gen::<f64>() < 0.3 {
608 let scale = rng.gen_range(0.1..10.0);
609 candidates.push(KernelGrammar::NonTerminal(NonTerminalOperation::Scale {
610 kernel: Box::new(expression.clone()),
611 scale,
612 }));
613 }
614
615 Ok(candidates)
616 }
617
618 fn tournament_selection<'a>(
620 &self,
621 population: &'a [(KernelGrammar, f64)],
622 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
623 ) -> &'a (KernelGrammar, f64) {
624 let tournament_size = 3.min(population.len());
625 let mut best_idx = rng.gen_range(0..population.len());
626 let mut best_score = population[best_idx].1;
627
628 for _ in 1..tournament_size {
629 let idx = rng.gen_range(0..population.len());
630 if population[idx].1 < best_score {
631 best_idx = idx;
632 best_score = population[idx].1;
633 }
634 }
635
636 &population[best_idx]
637 }
638
639 fn crossover(
641 &self,
642 parent1: &KernelGrammar,
643 parent2: &KernelGrammar,
644 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
645 ) -> SklResult<KernelGrammar> {
646 if rng.gen::<f64>() < 0.5 {
647 Ok(KernelGrammar::NonTerminal(NonTerminalOperation::Sum {
648 left: Box::new(parent1.clone()),
649 right: Box::new(parent2.clone()),
650 }))
651 } else {
652 Ok(KernelGrammar::NonTerminal(NonTerminalOperation::Product {
653 left: Box::new(parent1.clone()),
654 right: Box::new(parent2.clone()),
655 }))
656 }
657 }
658
659 fn mutate(
661 &self,
662 expression: &KernelGrammar,
663 rng: &mut scirs2_core::random::Random<scirs2_core::rngs::StdRng>,
664 ) -> SklResult<KernelGrammar> {
665 if rng.gen::<f64>() < 0.1 {
666 let new_kernels = self.generate_initial_kernels()?;
668 let idx = rng.gen_range(0..new_kernels.len());
669 Ok(new_kernels[idx].clone())
670 } else {
671 Ok(expression.clone())
673 }
674 }
675
676 fn expression_depth(&self, expression: &KernelGrammar) -> usize {
678 match expression {
679 KernelGrammar::Terminal(_) => 1,
680 KernelGrammar::NonTerminal(op) => match op {
681 NonTerminalOperation::Sum { left, right }
682 | NonTerminalOperation::Product { left, right } => {
683 1 + self
684 .expression_depth(left)
685 .max(self.expression_depth(right))
686 }
687 NonTerminalOperation::Scale { kernel, .. }
688 | NonTerminalOperation::Power { base: kernel, .. } => {
689 1 + self.expression_depth(kernel)
690 }
691 },
692 }
693 }
694
695 fn expression_to_kernel(&self, expression: &KernelGrammar) -> SklResult<Box<dyn Kernel>> {
697 match expression {
698 KernelGrammar::Terminal(terminal) => match terminal {
699 TerminalKernel::RBF { length_scale } => Ok(Box::new(RBF::new(*length_scale))),
700 TerminalKernel::Linear { sigma_0, sigma_1 } => {
701 Ok(Box::new(Linear::new(*sigma_0, *sigma_1)))
702 }
703 TerminalKernel::Periodic {
704 length_scale,
705 period,
706 } => Ok(Box::new(ExpSineSquared::new(*length_scale, *period))),
707 TerminalKernel::Matern { length_scale, nu } => {
708 Ok(Box::new(Matern::new(*length_scale, *nu)))
709 }
710 TerminalKernel::RationalQuadratic {
711 length_scale,
712 alpha,
713 } => Ok(Box::new(RationalQuadratic::new(*length_scale, *alpha))),
714 TerminalKernel::White { noise_level } => {
715 Ok(Box::new(WhiteKernel::new(*noise_level)))
716 }
717 TerminalKernel::Constant { constant_value } => {
718 Ok(Box::new(ConstantKernel::new(*constant_value)))
719 }
720 },
721 KernelGrammar::NonTerminal(op) => match op {
722 NonTerminalOperation::Sum { left, right } => {
723 let left_kernel = self.expression_to_kernel(left)?;
724 let right_kernel = self.expression_to_kernel(right)?;
725 Ok(Box::new(crate::kernels::SumKernel::new(vec![
726 left_kernel,
727 right_kernel,
728 ])))
729 }
730 NonTerminalOperation::Product { left, right } => {
731 let left_kernel = self.expression_to_kernel(left)?;
732 let right_kernel = self.expression_to_kernel(right)?;
733 Ok(Box::new(crate::kernels::ProductKernel::new(vec![
734 left_kernel,
735 right_kernel,
736 ])))
737 }
738 NonTerminalOperation::Scale { kernel, scale } => {
739 let base_kernel = self.expression_to_kernel(kernel)?;
740 let scale_kernel = Box::new(ConstantKernel::new(*scale));
742 Ok(Box::new(crate::kernels::ProductKernel::new(vec![
743 base_kernel,
744 scale_kernel,
745 ])))
746 }
747 NonTerminalOperation::Power { base, exponent: _ } => {
748 self.expression_to_kernel(base)
751 }
752 },
753 }
754 }
755
756 fn expression_to_string(&self, expression: &KernelGrammar) -> String {
758 match expression {
759 KernelGrammar::Terminal(terminal) => match terminal {
760 TerminalKernel::RBF { length_scale } => format!("RBF({:.3})", length_scale),
761 TerminalKernel::Linear { sigma_0, sigma_1 } => {
762 format!("Linear({:.3}, {:.3})", sigma_0, sigma_1)
763 }
764 TerminalKernel::Periodic {
765 length_scale,
766 period,
767 } => format!("Periodic({:.3}, {:.3})", length_scale, period),
768 TerminalKernel::Matern { length_scale, nu } => {
769 format!("Matern({:.3}, {:.3})", length_scale, nu)
770 }
771 TerminalKernel::RationalQuadratic {
772 length_scale,
773 alpha,
774 } => format!("RQ({:.3}, {:.3})", length_scale, alpha),
775 TerminalKernel::White { noise_level } => format!("White({:.3})", noise_level),
776 TerminalKernel::Constant { constant_value } => {
777 format!("Const({:.3})", constant_value)
778 }
779 },
780 KernelGrammar::NonTerminal(op) => match op {
781 NonTerminalOperation::Sum { left, right } => {
782 format!(
783 "({} + {})",
784 self.expression_to_string(left),
785 self.expression_to_string(right)
786 )
787 }
788 NonTerminalOperation::Product { left, right } => {
789 format!(
790 "({} * {})",
791 self.expression_to_string(left),
792 self.expression_to_string(right)
793 )
794 }
795 NonTerminalOperation::Scale { kernel, scale } => {
796 format!("{:.3} * {}", scale, self.expression_to_string(kernel))
797 }
798 NonTerminalOperation::Power { base, exponent } => {
799 format!("{}^{:.3}", self.expression_to_string(base), exponent)
800 }
801 },
802 }
803 }
804
805 fn evaluate_kernel(
807 &self,
808 kernel: &Box<dyn Kernel>,
809 X: &ArrayView2<f64>,
810 y: &ArrayView1<f64>,
811 ) -> SklResult<f64> {
812 self.compute_marginal_likelihood(kernel, X, y)
815 }
816
817 #[allow(non_snake_case)]
819 fn compute_marginal_likelihood(
820 &self,
821 kernel: &Box<dyn Kernel>,
822 X: &ArrayView2<f64>,
823 y: &ArrayView1<f64>,
824 ) -> SklResult<f64> {
825 let X_owned = X.to_owned();
826 let K = kernel.compute_kernel_matrix(&X_owned, Some(&X_owned))?;
827
828 let mut K_noisy = K;
830 let noise_var = 0.1;
831 for i in 0..K_noisy.nrows() {
832 K_noisy[[i, i]] += noise_var;
833 }
834
835 match crate::utils::cholesky_decomposition(&K_noisy) {
837 Ok(L) => {
838 let mut log_det = 0.0;
840 for i in 0..L.nrows() {
841 log_det += L[[i, i]].ln();
842 }
843 log_det *= 2.0;
844
845 let y_owned = y.to_owned();
847 let alpha = match crate::utils::triangular_solve(&L, &y_owned) {
848 Ok(temp) => {
849 let L_T = L.t();
850 crate::utils::triangular_solve(&L_T.view().to_owned(), &temp)?
851 }
852 Err(_) => return Ok(f64::INFINITY),
853 };
854
855 let data_fit = -0.5 * y.dot(&alpha);
856 let complexity_penalty = -0.5 * log_det;
857 let normalization = -0.5 * y.len() as f64 * (2.0 * std::f64::consts::PI).ln();
858
859 Ok(-(data_fit + complexity_penalty + normalization))
860 }
861 Err(_) => Ok(f64::INFINITY),
862 }
863 }
864}
865
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Builds an RBF leaf expression with the given length scale.
    fn rbf_leaf(length_scale: f64) -> KernelGrammar {
        KernelGrammar::Terminal(TerminalKernel::RBF { length_scale })
    }

    /// The default configuration is what `new()` returns.
    #[test]
    fn test_kernel_structure_learner_creation() {
        let learner = KernelStructureLearner::new();
        assert_eq!(learner.max_depth, 4);
        assert_eq!(learner.max_iterations, 100);
    }

    /// A leaf has depth 1 and a sum of two leaves has depth 2.
    #[test]
    fn test_expression_depth_calculation() {
        let learner = KernelStructureLearner::new();

        let leaf = rbf_leaf(1.0);
        assert_eq!(learner.expression_depth(&leaf), 1);

        let sum_of_leaves = KernelGrammar::NonTerminal(NonTerminalOperation::Sum {
            left: Box::new(leaf.clone()),
            right: Box::new(leaf.clone()),
        });
        assert_eq!(learner.expression_depth(&sum_of_leaves), 2);
    }

    /// A converted RBF expression yields a square kernel matrix over the inputs.
    #[test]
    #[allow(non_snake_case)]
    fn test_expression_to_kernel_conversion() {
        let learner = KernelStructureLearner::new();

        let kernel = learner.expression_to_kernel(&rbf_leaf(2.0)).unwrap();

        let X = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let K = kernel.compute_kernel_matrix(&X, Some(&X)).unwrap();
        assert_eq!(K.nrows(), 3);
        assert_eq!(K.ncols(), 3);
    }

    /// The default (greedy) strategy returns a finite score and records history.
    #[test]
    #[allow(non_snake_case)]
    fn test_greedy_search() {
        let learner = KernelStructureLearner::new().max_iterations(5).max_depth(2);

        let X = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = Array1::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0]);

        let outcome = learner.learn_structure(X.view(), y.view()).unwrap();

        assert!(outcome.best_score.is_finite());
        assert!(!outcome.exploration_history.is_empty());
        assert!(!outcome.convergence_info.score_history.is_empty());
    }

    /// The printed form of a sum names both operands and the operator.
    #[test]
    fn test_expression_to_string() {
        let learner = KernelStructureLearner::new();

        let linear_leaf = KernelGrammar::Terminal(TerminalKernel::Linear {
            sigma_0: 1.0,
            sigma_1: 1.0,
        });
        let sum = KernelGrammar::NonTerminal(NonTerminalOperation::Sum {
            left: Box::new(rbf_leaf(1.0)),
            right: Box::new(linear_leaf),
        });

        let rendered = learner.expression_to_string(&sum);
        for fragment in ["RBF", "Linear", "+"] {
            assert!(rendered.contains(fragment));
        }
    }

    /// Beam search completes and records the candidates it evaluated.
    #[test]
    #[allow(non_snake_case)]
    fn test_beam_search() {
        let learner = KernelStructureLearner::new()
            .max_iterations(3)
            .max_depth(2)
            .search_strategy(SearchStrategy::Beam { beam_width: 2 });

        let X = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let outcome = learner.learn_structure(X.view(), y.view()).unwrap();

        // The score may legitimately be +infinity when no kernel could be scored.
        assert!(outcome.best_score.is_finite() || outcome.best_score == f64::INFINITY);
        assert!(!outcome.exploration_history.is_empty());
    }
}