1use crate::cost_curve::CostCurvePoint;
27use crate::transfer::{ArmId, BetaParams, ContextBucket};
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30
/// Per-context-bucket regret bookkeeping.
///
/// Tracks the empirically best arm seen so far in the bucket and the
/// cumulative pseudo-regret accrued against that arm's mean.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BucketRegret {
    /// Highest empirical mean reward observed for any arm in this bucket.
    pub best_mean: f32,
    /// Arm currently holding `best_mean`.
    pub best_arm: ArmId,
    /// Running sum of instantaneous regret (best mean minus reward, floored at 0).
    pub cumulative_regret: f64,
    /// Total rewards recorded in this bucket.
    pub observations: u64,
    /// Snapshots of `cumulative_regret`, taken every snapshot interval.
    pub regret_history: Vec<f64>,
    /// Per-arm (reward sum, count) pairs used to derive empirical means.
    arm_means: HashMap<ArmId, (f64, u64)>,
}
51
52impl BucketRegret {
53 fn new() -> Self {
54 Self {
55 best_mean: 0.0,
56 best_arm: ArmId("unknown".into()),
57 cumulative_regret: 0.0,
58 observations: 0,
59 regret_history: Vec::new(),
60 arm_means: HashMap::new(),
61 }
62 }
63}
64
/// Aggregates pseudo-regret across all context buckets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegretTracker {
    /// Per-bucket regret state.
    buckets: HashMap<ContextBucket, BucketRegret>,
    /// Sum of instantaneous regret across every bucket.
    pub total_regret: f64,
    /// Total rewards recorded across every bucket.
    pub total_observations: u64,
    /// Every this many observations a bucket snapshots its cumulative regret.
    snapshot_interval: u64,
}
80
81impl RegretTracker {
82 pub fn new(snapshot_interval: u64) -> Self {
84 Self {
85 buckets: HashMap::new(),
86 total_regret: 0.0,
87 total_observations: 0,
88 snapshot_interval: snapshot_interval.max(1),
89 }
90 }
91
92 pub fn record(
94 &mut self,
95 bucket: &ContextBucket,
96 arm: &ArmId,
97 reward: f32,
98 ) {
99 if !self.buckets.contains_key(bucket) {
101 self.buckets.insert(bucket.clone(), BucketRegret::new());
102 }
103 let entry = self.buckets.get_mut(bucket).unwrap();
104
105 if !entry.arm_means.contains_key(arm) {
107 entry.arm_means.insert(arm.clone(), (0.0, 0));
108 }
109 let (sum, count) = entry.arm_means.get_mut(arm).unwrap();
110 *sum += reward as f64;
111 *count += 1;
112 let arm_mean = *sum / *count as f64;
113
114 if arm_mean > entry.best_mean as f64 {
116 entry.best_mean = arm_mean as f32;
117 entry.best_arm = arm.clone();
118 }
119
120 let instant_regret = (entry.best_mean as f64 - reward as f64).max(0.0);
122 entry.cumulative_regret += instant_regret;
123 entry.observations += 1;
124 self.total_regret += instant_regret;
125 self.total_observations += 1;
126
127 if entry.observations % self.snapshot_interval == 0 {
129 entry.regret_history.push(entry.cumulative_regret);
130 }
131 }
132
133 pub fn regret_growth_rate(&self, bucket: &ContextBucket) -> Option<f32> {
138 let entry = self.buckets.get(bucket)?;
139 if entry.observations < 10 || entry.cumulative_regret < 1e-10 {
140 return None;
141 }
142 let log_regret = (entry.cumulative_regret).ln();
143 let log_t = (entry.observations as f64).ln();
144 Some((log_regret / log_t) as f32)
145 }
146
147 pub fn average_regret(&self) -> f64 {
149 if self.total_observations == 0 {
150 return 0.0;
151 }
152 self.total_regret / self.total_observations as f64
153 }
154
155 pub fn has_converged(&self, bucket: &ContextBucket, threshold: f32) -> bool {
157 self.regret_growth_rate(bucket)
158 .map_or(false, |rate| rate < threshold)
159 }
160
161 pub fn summary(&self) -> RegretSummary {
163 let bucket_rates: Vec<(ContextBucket, f32)> = self
164 .buckets
165 .keys()
166 .filter_map(|b| self.regret_growth_rate(b).map(|r| (b.clone(), r)))
167 .collect();
168
169 let mean_rate = if bucket_rates.is_empty() {
170 1.0
171 } else {
172 bucket_rates.iter().map(|(_, r)| r).sum::<f32>() / bucket_rates.len() as f32
173 };
174
175 RegretSummary {
176 total_regret: self.total_regret,
177 total_observations: self.total_observations,
178 average_regret: self.average_regret(),
179 mean_growth_rate: mean_rate,
180 bucket_count: self.buckets.len(),
181 converged_buckets: bucket_rates.iter().filter(|(_, r)| *r < 0.7).count(),
182 }
183 }
184}
185
/// Point-in-time snapshot of tracker-wide regret statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegretSummary {
    /// Sum of instantaneous regret across all buckets.
    pub total_regret: f64,
    /// Total rewards recorded.
    pub total_observations: u64,
    /// total_regret / total_observations (0.0 when empty).
    pub average_regret: f64,
    /// Mean regret-growth exponent across buckets with a computable rate;
    /// 1.0 (linear regret) when none is computable.
    pub mean_growth_rate: f32,
    /// Number of buckets seen so far.
    pub bucket_count: usize,
    /// Buckets whose growth exponent is below 0.7.
    pub converged_buckets: usize,
}
198
/// Beta posterior with exponential forgetting, for non-stationary bandits.
///
/// Older evidence is multiplied by `decay_factor` on every update, so the
/// distribution tracks a sliding effective window of recent rewards.
/// Rewards are expected in [0, 1] — NOTE(review): confirm at call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecayingBeta {
    /// Pseudo-count of successes; Beta(1, 1) is the prior.
    pub alpha: f32,
    /// Pseudo-count of failures.
    pub beta: f32,
    /// Per-update multiplier applied to past evidence, clamped to [0.9, 1.0].
    pub decay_factor: f32,
    /// Decayed count of observations absorbed so far.
    pub effective_n: f32,
}
221
222impl DecayingBeta {
223 pub fn new(decay_factor: f32) -> Self {
225 Self {
226 alpha: 1.0,
227 beta: 1.0,
228 decay_factor: decay_factor.clamp(0.9, 1.0),
229 effective_n: 0.0,
230 }
231 }
232
233 pub fn from_beta(params: &BetaParams, decay_factor: f32) -> Self {
235 Self {
236 alpha: params.alpha,
237 beta: params.beta,
238 decay_factor: decay_factor.clamp(0.9, 1.0),
239 effective_n: params.alpha + params.beta - 2.0,
240 }
241 }
242
243 pub fn update(&mut self, reward: f32) {
245 self.alpha = 1.0 + (self.alpha - 1.0) * self.decay_factor;
247 self.beta = 1.0 + (self.beta - 1.0) * self.decay_factor;
248
249 self.alpha += reward;
251 self.beta += 1.0 - reward;
252
253 self.effective_n = self.effective_n * self.decay_factor + 1.0;
255 }
256
257 pub fn mean(&self) -> f32 {
259 self.alpha / (self.alpha + self.beta)
260 }
261
262 pub fn variance(&self) -> f32 {
264 let total = self.alpha + self.beta;
265 (self.alpha * self.beta) / (total * total * (total + 1.0))
266 }
267
268 pub fn to_beta_params(&self) -> BetaParams {
270 BetaParams {
271 alpha: self.alpha,
272 beta: self.beta,
273 }
274 }
275
276 pub fn effective_window(&self) -> f32 {
278 if self.decay_factor >= 1.0 {
279 self.effective_n
280 } else {
281 1.0 / (1.0 - self.decay_factor)
282 }
283 }
284}
285
/// Detects accuracy plateaus by comparing the means of two adjacent windows
/// of cost-curve points, escalating its suggested response over time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlateauDetector {
    /// Number of points per comparison window (floored at 3).
    pub window_size: usize,
    /// Minimum absolute change in windowed mean that counts as progress
    /// (floored at 0.001).
    pub improvement_threshold: f32,
    /// Plateaus observed back-to-back; reset on any non-plateau check.
    pub consecutive_plateaus: u32,
    /// Lifetime count of plateau detections.
    pub total_plateaus: u32,
}
306
/// Suggested response to a (possible) plateau, in escalating order.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PlateauAction {
    /// Learning is still progressing; no intervention.
    Continue,
    /// First plateau in a row: widen exploration.
    IncreaseExploration,
    /// Plateau persists (2-3 in a row): trigger knowledge transfer.
    TriggerTransfer,
    /// Long plateau (4-6 in a row): inject diversity.
    InjectDiversity,
    /// Plateau persists beyond that: reset.
    Reset,
}
321
322impl PlateauDetector {
323 pub fn new(window_size: usize, improvement_threshold: f32) -> Self {
325 Self {
326 window_size: window_size.max(3),
327 improvement_threshold: improvement_threshold.max(0.001),
328 consecutive_plateaus: 0,
329 total_plateaus: 0,
330 }
331 }
332
333 pub fn check(&mut self, points: &[CostCurvePoint]) -> PlateauAction {
335 if points.len() < self.window_size * 2 {
336 self.consecutive_plateaus = 0;
337 return PlateauAction::Continue;
338 }
339
340 let n = points.len();
341 let recent = &points[n - self.window_size..];
342 let prior = &points[n - 2 * self.window_size..n - self.window_size];
343
344 let recent_mean = recent.iter().map(|p| p.accuracy).sum::<f32>()
345 / recent.len() as f32;
346 let prior_mean = prior.iter().map(|p| p.accuracy).sum::<f32>()
347 / prior.len() as f32;
348
349 let improvement = recent_mean - prior_mean;
350
351 if improvement.abs() < self.improvement_threshold {
352 self.consecutive_plateaus += 1;
353 self.total_plateaus += 1;
354
355 match self.consecutive_plateaus {
356 1 => PlateauAction::IncreaseExploration,
357 2..=3 => PlateauAction::TriggerTransfer,
358 4..=6 => PlateauAction::InjectDiversity,
359 _ => PlateauAction::Reset,
360 }
361 } else {
362 self.consecutive_plateaus = 0;
363 PlateauAction::Continue
364 }
365 }
366
367 pub fn check_cost(&self, points: &[CostCurvePoint]) -> bool {
369 if points.len() < self.window_size * 2 {
370 return false;
371 }
372
373 let n = points.len();
374 let recent = &points[n - self.window_size..];
375 let prior = &points[n - 2 * self.window_size..n - self.window_size];
376
377 let recent_cost = recent.iter().map(|p| p.cost_per_solve).sum::<f32>()
378 / recent.len() as f32;
379 let prior_cost = prior.iter().map(|p| p.cost_per_solve).sum::<f32>()
380 / prior.len() as f32;
381
382 (prior_cost - recent_cost).abs() < self.improvement_threshold
384 }
385
386 pub fn learning_velocity(&self, points: &[CostCurvePoint]) -> f32 {
388 if points.len() < 2 {
389 return 0.0;
390 }
391 let n = points.len();
392 let window = self.window_size.min(n);
393 let recent = &points[n - window..];
394
395 if recent.len() < 2 {
396 return 0.0;
397 }
398
399 let first = recent.first().unwrap();
400 let last = recent.last().unwrap();
401 let dt = (last.cycle - first.cycle).max(1) as f32;
402
403 (last.accuracy - first.accuracy) / dt
404 }
405}
406
/// One candidate evaluated against (or living on) the Pareto front.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParetoPoint {
    /// Identifier of the kernel this point belongs to.
    pub kernel_id: String,
    /// Objective values; all are maximized, so costs must be negated by the
    /// caller (as `MetaLearningEngine::record_kernel` does).
    pub objectives: Vec<f32>,
    /// Generation at which the kernel was produced.
    pub generation: u32,
}
422
/// Maintains the set of mutually non-dominated points (maximization).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ParetoFront {
    /// Current non-dominated set.
    pub front: Vec<ParetoPoint>,
    /// Total points ever offered via `insert`.
    pub evaluated: u64,
    /// Number of successful insertions.
    pub front_updates: u64,
}
437
438impl ParetoFront {
439 pub fn new() -> Self {
440 Self::default()
441 }
442
443 pub fn dominates(a: &[f32], b: &[f32]) -> bool {
448 if a.len() != b.len() {
449 return false;
450 }
451 let mut at_least_equal = true;
452 let mut strictly_better = false;
453
454 for (ai, bi) in a.iter().zip(b.iter()) {
455 if ai < bi {
456 at_least_equal = false;
457 break;
458 }
459 if ai > bi {
460 strictly_better = true;
461 }
462 }
463
464 at_least_equal && strictly_better
465 }
466
467 pub fn insert(&mut self, point: ParetoPoint) -> bool {
471 self.evaluated += 1;
472
473 for existing in &self.front {
475 if Self::dominates(&existing.objectives, &point.objectives) {
476 return false; }
478 }
479
480 self.front
482 .retain(|existing| !Self::dominates(&point.objectives, &existing.objectives));
483
484 self.front.push(point);
485 self.front_updates += 1;
486 true
487 }
488
489 pub fn hypervolume(&self, reference: &[f32]) -> f32 {
495 if self.front.is_empty() || reference.is_empty() {
496 return 0.0;
497 }
498
499 let dim = reference.len();
500 if dim == 2 {
501 self.hypervolume_2d(reference)
502 } else {
503 self.front
505 .iter()
506 .map(|p| {
507 p.objectives
508 .iter()
509 .zip(reference.iter())
510 .map(|(oi, ri)| (oi - ri).max(0.0))
511 .product::<f32>()
512 })
513 .sum()
514 }
515 }
516
517 fn hypervolume_2d(&self, reference: &[f32]) -> f32 {
519 if self.front.is_empty() {
520 return 0.0;
521 }
522
523 let mut points: Vec<(f32, f32)> = self
524 .front
525 .iter()
526 .map(|p| {
527 let x = p.objectives.first().copied().unwrap_or(0.0);
528 let y = p.objectives.get(1).copied().unwrap_or(0.0);
529 (x, y)
530 })
531 .collect();
532
533 points.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
535
536 let ref_x = reference.first().copied().unwrap_or(0.0);
537 let ref_y = reference.get(1).copied().unwrap_or(0.0);
538
539 let mut volume = 0.0f32;
540 let mut prev_y = ref_y;
541
542 for &(x, y) in &points {
543 if y > prev_y {
544 volume += (x - ref_x) * (y - prev_y);
545 prev_y = y;
546 }
547 }
548
549 volume
550 }
551
552 pub fn len(&self) -> usize {
554 self.front.len()
555 }
556
557 pub fn is_empty(&self) -> bool {
559 self.front.is_empty()
560 }
561
562 pub fn best_on(&self, objective_index: usize) -> Option<&ParetoPoint> {
564 self.front
565 .iter()
566 .max_by(|a, b| {
567 let va = a.objectives.get(objective_index).copied().unwrap_or(0.0);
568 let vb = b.objectives.get(objective_index).copied().unwrap_or(0.0);
569 va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
570 })
571 }
572
573 pub fn spread(&self) -> Vec<f32> {
575 if self.front.is_empty() {
576 return Vec::new();
577 }
578 let dim = self.front[0].objectives.len();
579 (0..dim)
580 .map(|i| {
581 let vals: Vec<f32> = self.front.iter().map(|p| p.objectives[i]).collect();
582 let min = vals.iter().cloned().fold(f32::INFINITY, f32::min);
583 let max = vals.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
584 max - min
585 })
586 .collect()
587 }
588}
589
/// UCB-style count-based exploration bonus per (bucket, arm) pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CuriosityBonus {
    /// Visit counts, keyed by bucket then by arm.
    visit_counts: HashMap<ContextBucket, HashMap<ArmId, u64>>,
    /// Total visits across all buckets and arms.
    pub total_visits: u64,
    /// Scale applied to every bonus (clamped non-negative at construction).
    pub exploration_coeff: f32,
}
608
609impl CuriosityBonus {
610 pub fn new(exploration_coeff: f32) -> Self {
613 Self {
614 visit_counts: HashMap::new(),
615 total_visits: 0,
616 exploration_coeff: exploration_coeff.max(0.0),
617 }
618 }
619
620 pub fn record_visit(&mut self, bucket: &ContextBucket, arm: &ArmId) {
622 if let Some(arms) = self.visit_counts.get_mut(bucket) {
624 if let Some(count) = arms.get_mut(arm) {
625 *count += 1;
626 } else {
627 arms.insert(arm.clone(), 1);
628 }
629 } else {
630 let mut arms = HashMap::new();
631 arms.insert(arm.clone(), 1u64);
632 self.visit_counts.insert(bucket.clone(), arms);
633 }
634 self.total_visits += 1;
635 }
636
637 pub fn bonus(&self, bucket: &ContextBucket, arm: &ArmId) -> f32 {
644 if self.total_visits < 2 {
645 return self.exploration_coeff; }
647
648 let arm_visits = self
649 .visit_counts
650 .get(bucket)
651 .and_then(|arms| arms.get(arm))
652 .copied()
653 .unwrap_or(0);
654
655 if arm_visits == 0 {
656 return self.exploration_coeff * 2.0; }
658
659 let log_n = (self.total_visits as f32).ln();
660 self.exploration_coeff * (log_n / arm_visits as f32).sqrt()
661 }
662
663 pub fn most_curious_bucket(&self) -> Option<&ContextBucket> {
665 let mut min_visits = u64::MAX;
667 let mut most_curious = None;
668
669 for (bucket, arms) in &self.visit_counts {
670 let total: u64 = arms.values().sum();
671 if total < min_visits {
672 min_visits = total;
673 most_curious = Some(bucket);
674 }
675 }
676
677 most_curious
678 }
679
680 pub fn novelty_score(&self, bucket: &ContextBucket) -> f32 {
683 if self.total_visits == 0 {
684 return 1.0;
685 }
686
687 let bucket_visits: u64 = self
688 .visit_counts
689 .get(bucket)
690 .map(|arms| arms.values().sum())
691 .unwrap_or(0);
692
693 if bucket_visits == 0 {
694 return 1.0;
695 }
696
697 1.0 - (bucket_visits as f32 / self.total_visits as f32)
698 }
699}
700
/// Facade bundling the meta-learning components: regret tracking, plateau
/// detection, Pareto archiving, curiosity bonuses, and decayed posteriors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningEngine {
    /// Cumulative-regret accounting per context bucket.
    pub regret: RegretTracker,
    /// Accuracy-plateau detection over cost-curve history.
    pub plateau: PlateauDetector,
    /// Archive of mutually non-dominated kernels.
    pub pareto: ParetoFront,
    /// Count-based exploration bonuses.
    pub curiosity: CuriosityBonus,
    /// Per-(bucket, arm) Beta posteriors with exponential forgetting.
    pub decaying_betas: HashMap<(ContextBucket, ArmId), DecayingBeta>,
    /// Decay factor used when creating new `DecayingBeta` entries.
    decay_factor: f32,
}
721
722impl MetaLearningEngine {
723 pub fn new() -> Self {
725 Self {
726 regret: RegretTracker::new(50),
727 plateau: PlateauDetector::new(5, 0.005),
728 pareto: ParetoFront::new(),
729 curiosity: CuriosityBonus::new(1.41),
730 decaying_betas: HashMap::new(),
731 decay_factor: 0.995,
732 }
733 }
734
735 pub fn with_config(
737 regret_snapshot_interval: u64,
738 plateau_window: usize,
739 plateau_threshold: f32,
740 exploration_coeff: f32,
741 decay_factor: f32,
742 ) -> Self {
743 Self {
744 regret: RegretTracker::new(regret_snapshot_interval),
745 plateau: PlateauDetector::new(plateau_window, plateau_threshold),
746 pareto: ParetoFront::new(),
747 curiosity: CuriosityBonus::new(exploration_coeff),
748 decaying_betas: HashMap::new(),
749 decay_factor,
750 }
751 }
752
753 pub fn record_decision(
755 &mut self,
756 bucket: &ContextBucket,
757 arm: &ArmId,
758 reward: f32,
759 ) {
760 self.regret.record(bucket, arm, reward);
762
763 self.curiosity.record_visit(bucket, arm);
765
766 let key = (bucket.clone(), arm.clone());
769 if let Some(db) = self.decaying_betas.get_mut(&key) {
770 db.update(reward);
771 } else {
772 let mut db = DecayingBeta::new(self.decay_factor);
773 db.update(reward);
774 self.decaying_betas.insert(key, db);
775 }
776 }
777
778 pub fn record_kernel(
780 &mut self,
781 kernel_id: &str,
782 accuracy: f32,
783 cost: f32,
784 robustness: f32,
785 generation: u32,
786 ) {
787 let point = ParetoPoint {
788 kernel_id: kernel_id.to_string(),
789 objectives: vec![accuracy, -cost, robustness],
791 generation,
792 };
793 self.pareto.insert(point);
794 }
795
796 pub fn check_plateau(&mut self, points: &[CostCurvePoint]) -> PlateauAction {
798 self.plateau.check(points)
799 }
800
801 pub fn boosted_score(
805 &self,
806 bucket: &ContextBucket,
807 arm: &ArmId,
808 thompson_sample: f32,
809 ) -> f32 {
810 let bonus = self.curiosity.bonus(bucket, arm);
811 thompson_sample + bonus
812 }
813
814 pub fn decaying_mean(
816 &self,
817 bucket: &ContextBucket,
818 arm: &ArmId,
819 ) -> Option<f32> {
820 let key = (bucket.clone(), arm.clone());
821 self.decaying_betas.get(&key).map(|db| db.mean())
822 }
823
824 pub fn health_check(&self) -> MetaLearningHealth {
826 let regret_summary = self.regret.summary();
827 let pareto_size = self.pareto.len();
828
829 let is_learning = regret_summary.mean_growth_rate < 0.8;
830 let is_diverse = pareto_size >= 3;
831 let is_exploring = self.curiosity.total_visits > 0;
832
833 MetaLearningHealth {
834 regret: regret_summary,
835 pareto_size,
836 pareto_hypervolume: self.pareto.hypervolume(&[0.0, -1.0, 0.0]),
837 consecutive_plateaus: self.plateau.consecutive_plateaus,
838 total_plateaus: self.plateau.total_plateaus,
839 curiosity_total_visits: self.curiosity.total_visits,
840 decaying_beta_count: self.decaying_betas.len(),
841 is_learning,
842 is_diverse,
843 is_exploring,
844 }
845 }
846}
847
impl Default for MetaLearningEngine {
    /// Equivalent to [`MetaLearningEngine::new`].
    fn default() -> Self {
        Self::new()
    }
}
853
/// Aggregated snapshot produced by `MetaLearningEngine::health_check`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningHealth {
    /// Tracker-wide regret statistics.
    pub regret: RegretSummary,
    /// Number of points currently on the Pareto front.
    pub pareto_size: usize,
    /// Hypervolume of the front w.r.t. the reference point [0.0, -1.0, 0.0].
    pub pareto_hypervolume: f32,
    /// Consecutive plateau detections so far.
    pub consecutive_plateaus: u32,
    /// Lifetime plateau detections.
    pub total_plateaus: u32,
    /// Total curiosity visits recorded.
    pub curiosity_total_visits: u64,
    /// Number of (bucket, arm) posteriors being tracked.
    pub decaying_beta_count: usize,
    /// Mean regret growth exponent is below 0.8.
    pub is_learning: bool,
    /// Pareto front holds at least 3 points.
    pub is_diverse: bool,
    /// At least one decision has been recorded.
    pub is_exploring: bool,
}
871
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a context bucket for tests.
    fn test_bucket(tier: &str, cat: &str) -> ContextBucket {
        ContextBucket {
            difficulty_tier: tier.into(),
            category: cat.into(),
        }
    }

    /// Build a cost-curve point with fixed cost/robustness and the given
    /// cycle and accuracy (timestamp mirrors the cycle).
    fn curve_point(cycle: u64, accuracy: f32) -> CostCurvePoint {
        CostCurvePoint {
            cycle,
            accuracy,
            cost_per_solve: 0.1,
            robustness: 0.8,
            policy_violations: 0,
            timestamp: cycle as f64,
        }
    }

    /// Build a Pareto point for tests.
    fn pareto_point(id: &str, objectives: Vec<f32>, generation: u32) -> ParetoPoint {
        ParetoPoint {
            kernel_id: id.into(),
            objectives,
            generation,
        }
    }

    // ---- RegretTracker ----

    #[test]
    fn test_regret_tracker_empty() {
        let tracker = RegretTracker::new(10);
        assert_eq!(tracker.total_regret, 0.0);
        assert_eq!(tracker.average_regret(), 0.0);
    }

    #[test]
    fn test_regret_tracker_optimal_arm() {
        let mut tracker = RegretTracker::new(10);
        let bucket = test_bucket("easy", "test");
        let arm = ArmId("best".into());

        for _ in 0..100 {
            tracker.record(&bucket, &arm, 0.9);
        }

        // Always playing the single best arm accrues (numerically) no regret.
        assert_eq!(tracker.total_observations, 100);
        assert!(tracker.total_regret < 1e-6);
    }

    #[test]
    fn test_regret_tracker_suboptimal_arm() {
        let mut tracker = RegretTracker::new(10);
        let bucket = test_bucket("medium", "test");
        let good = ArmId("good".into());
        let bad = ArmId("bad".into());

        for _ in 0..50 {
            tracker.record(&bucket, &good, 0.9);
        }
        for _ in 0..50 {
            tracker.record(&bucket, &bad, 0.3);
        }

        assert!(tracker.total_regret > 0.0);
        assert!(tracker.average_regret() > 0.0);
    }

    #[test]
    fn test_regret_growth_rate() {
        let mut tracker = RegretTracker::new(5);
        let bucket = test_bucket("hard", "test");
        let arm_a = ArmId("a".into());
        let arm_b = ArmId("b".into());

        for _ in 0..50 {
            tracker.record(&bucket, &arm_a, 0.8);
        }
        for _ in 0..50 {
            tracker.record(&bucket, &arm_b, 0.4);
        }

        assert!(tracker.regret_growth_rate(&bucket).is_some());
    }

    #[test]
    fn test_regret_summary() {
        let mut tracker = RegretTracker::new(10);
        let bucket = test_bucket("easy", "algo");
        let arm = ArmId("test".into());

        for _ in 0..20 {
            tracker.record(&bucket, &arm, 0.7);
        }

        let summary = tracker.summary();
        assert_eq!(summary.total_observations, 20);
        assert_eq!(summary.bucket_count, 1);
    }

    // ---- DecayingBeta ----

    #[test]
    fn test_decaying_beta_initial() {
        // Beta(1, 1) prior: mean 0.5, no evidence absorbed yet.
        let db = DecayingBeta::new(0.995);
        assert!((db.mean() - 0.5).abs() < 1e-6);
        assert_eq!(db.effective_n, 0.0);
    }

    #[test]
    fn test_decaying_beta_update() {
        let mut db = DecayingBeta::new(0.995);

        for _ in 0..100 {
            db.update(0.9);
        }

        assert!(db.mean() > 0.7);
        assert!(db.effective_n > 50.0);
    }

    #[test]
    fn test_decaying_beta_adapts() {
        let mut db = DecayingBeta::new(0.99);
        for _ in 0..100 {
            db.update(0.95);
        }
        let mean_after_good = db.mean();
        assert!(mean_after_good > 0.8);

        for _ in 0..100 {
            db.update(0.1);
        }
        let mean_after_bad = db.mean();

        // Forgetting lets the posterior follow the regime change.
        assert!(mean_after_bad < mean_after_good);
        assert!(mean_after_bad < 0.5);
    }

    #[test]
    fn test_decaying_beta_window() {
        // Effective window is 1 / (1 - decay).
        let db = DecayingBeta::new(0.99);
        assert!((db.effective_window() - 100.0).abs() < 1.0);

        let db2 = DecayingBeta::new(0.995);
        assert!((db2.effective_window() - 200.0).abs() < 1.0);
    }

    #[test]
    fn test_decaying_to_standard() {
        let mut db = DecayingBeta::new(0.995);
        for _ in 0..10 {
            db.update(0.8);
        }
        let params = db.to_beta_params();
        assert!(params.alpha > 1.0);
        assert!(params.beta > 1.0);
        assert!((params.mean() - db.mean()).abs() < 1e-6);
    }

    // ---- PlateauDetector ----

    #[test]
    fn test_plateau_no_data() {
        let mut detector = PlateauDetector::new(3, 0.01);
        assert_eq!(detector.check(&[]), PlateauAction::Continue);
    }

    #[test]
    fn test_plateau_not_enough_data() {
        // 4 points < 2 * window_size (6): detection cannot run.
        let mut detector = PlateauDetector::new(3, 0.01);
        let points: Vec<CostCurvePoint> = (0..4)
            .map(|i| curve_point(i as u64, 0.5 + i as f32 * 0.1))
            .collect();

        assert_eq!(detector.check(&points), PlateauAction::Continue);
    }

    #[test]
    fn test_plateau_detected() {
        // Nearly flat accuracy: inter-window change is well under 0.01.
        let mut detector = PlateauDetector::new(3, 0.01);
        let points: Vec<CostCurvePoint> = (0..6)
            .map(|i| curve_point(i as u64, 0.80 + i as f32 * 0.001))
            .collect();

        assert_ne!(detector.check(&points), PlateauAction::Continue);
    }

    #[test]
    fn test_plateau_improving() {
        let mut detector = PlateauDetector::new(3, 0.01);
        let points: Vec<CostCurvePoint> = (0..6)
            .map(|i| curve_point(i as u64, 0.50 + i as f32 * 0.08))
            .collect();

        assert_eq!(detector.check(&points), PlateauAction::Continue);
    }

    #[test]
    fn test_plateau_escalation() {
        let mut detector = PlateauDetector::new(3, 0.01);
        let flat_points: Vec<CostCurvePoint> =
            (0..6).map(|i| curve_point(i as u64, 0.80)).collect();

        // Responses escalate across consecutive plateau detections.
        assert_eq!(detector.check(&flat_points), PlateauAction::IncreaseExploration);
        assert_eq!(detector.check(&flat_points), PlateauAction::TriggerTransfer);
        assert_eq!(detector.check(&flat_points), PlateauAction::TriggerTransfer);
        assert_eq!(detector.check(&flat_points), PlateauAction::InjectDiversity);
    }

    #[test]
    fn test_learning_velocity() {
        let detector = PlateauDetector::new(3, 0.01);
        let points: Vec<CostCurvePoint> = (0..6)
            .map(|i| curve_point(i as u64, 0.50 + i as f32 * 0.1))
            .collect();

        assert!(detector.learning_velocity(&points) > 0.0);
    }

    // ---- ParetoFront ----

    #[test]
    fn test_pareto_dominates() {
        assert!(ParetoFront::dominates(&[0.9, -0.1, 0.8], &[0.8, -0.2, 0.7]));
        assert!(!ParetoFront::dominates(&[0.9, -0.3, 0.8], &[0.8, -0.1, 0.7]));
        // A point never dominates itself (no strict improvement).
        assert!(!ParetoFront::dominates(&[0.9, -0.1, 0.8], &[0.9, -0.1, 0.8]));
    }

    #[test]
    fn test_pareto_insert_non_dominated() {
        let mut front = ParetoFront::new();

        assert!(front.insert(pareto_point("a", vec![0.9, -0.3, 0.7], 0)));
        assert!(front.insert(pareto_point("b", vec![0.7, -0.1, 0.9], 0)));

        assert_eq!(front.len(), 2);
    }

    #[test]
    fn test_pareto_insert_dominated() {
        let mut front = ParetoFront::new();

        front.insert(pareto_point("good", vec![0.9, -0.1, 0.9], 0));
        let added = front.insert(pareto_point("bad", vec![0.5, -0.5, 0.5], 0));

        assert!(!added);
        assert_eq!(front.len(), 1);
    }

    #[test]
    fn test_pareto_removes_dominated() {
        let mut front = ParetoFront::new();

        front.insert(pareto_point("old", vec![0.5, -0.3, 0.5], 0));
        front.insert(pareto_point("new", vec![0.9, -0.1, 0.9], 1));

        assert_eq!(front.len(), 1);
        assert_eq!(front.front[0].kernel_id, "new");
    }

    #[test]
    fn test_pareto_best_on_objective() {
        let mut front = ParetoFront::new();

        front.insert(pareto_point("accurate", vec![0.95, -0.5, 0.6], 0));
        front.insert(pareto_point("cheap", vec![0.7, -0.05, 0.7], 0));
        front.insert(pareto_point("robust", vec![0.8, -0.3, 0.95], 0));

        assert_eq!(front.best_on(0).unwrap().kernel_id, "accurate");
        assert_eq!(front.best_on(1).unwrap().kernel_id, "cheap");
        assert_eq!(front.best_on(2).unwrap().kernel_id, "robust");
    }

    #[test]
    fn test_pareto_spread() {
        let mut front = ParetoFront::new();

        front.insert(pareto_point("a", vec![0.9, -0.5], 0));
        front.insert(pareto_point("b", vec![0.5, -0.1], 0));

        assert_eq!(front.len(), 2);
        let spread = front.spread();
        assert_eq!(spread.len(), 2);
        assert!((spread[0] - 0.4).abs() < 1e-4);
        assert!((spread[1] - 0.4).abs() < 1e-4);
    }

    #[test]
    fn test_pareto_hypervolume_2d() {
        let mut front = ParetoFront::new();

        front.insert(pareto_point("a", vec![1.0, 1.0], 0));

        // One point at (1, 1) against the origin: unit square.
        let hv = front.hypervolume(&[0.0, 0.0]);
        assert!((hv - 1.0).abs() < 1e-4);
    }

    // ---- CuriosityBonus ----

    #[test]
    fn test_curiosity_bonus_unvisited() {
        let curiosity = CuriosityBonus::new(1.41);
        let bucket = test_bucket("hard", "novel");
        let arm = ArmId("new".into());

        assert!(curiosity.bonus(&bucket, &arm) > 0.0);
    }

    #[test]
    fn test_curiosity_bonus_decays_with_visits() {
        let mut curiosity = CuriosityBonus::new(1.41);
        let bucket = test_bucket("easy", "test");
        let arm = ArmId("a".into());

        let bonus_before = curiosity.bonus(&bucket, &arm);
        for _ in 0..50 {
            curiosity.record_visit(&bucket, &arm);
        }
        let bonus_after = curiosity.bonus(&bucket, &arm);

        assert!(bonus_after < bonus_before);
    }

    #[test]
    fn test_curiosity_novelty_score() {
        let mut curiosity = CuriosityBonus::new(1.41);
        let explored = test_bucket("easy", "common");
        let novel = test_bucket("hard", "rare");
        let arm = ArmId("a".into());

        for _ in 0..100 {
            curiosity.record_visit(&explored, &arm);
        }
        curiosity.record_visit(&novel, &arm);

        assert!(curiosity.novelty_score(&novel) > curiosity.novelty_score(&explored));
    }

    // ---- MetaLearningEngine ----

    #[test]
    fn test_meta_engine_creation() {
        let engine = MetaLearningEngine::new();
        assert_eq!(engine.regret.total_observations, 0);
        assert!(engine.pareto.is_empty());
        assert_eq!(engine.curiosity.total_visits, 0);
    }

    #[test]
    fn test_meta_engine_record_decision() {
        let mut engine = MetaLearningEngine::new();
        let bucket = test_bucket("medium", "algo");
        let arm = ArmId("greedy".into());

        for _ in 0..50 {
            engine.record_decision(&bucket, &arm, 0.85);
        }

        assert_eq!(engine.regret.total_observations, 50);
        assert_eq!(engine.curiosity.total_visits, 50);
        assert!(engine.decaying_mean(&bucket, &arm).unwrap() > 0.7);
    }

    #[test]
    fn test_meta_engine_boosted_score() {
        let mut engine = MetaLearningEngine::new();
        let explored = test_bucket("easy", "common");
        let novel = test_bucket("hard", "rare");
        let arm = ArmId("a".into());

        for _ in 0..100 {
            engine.record_decision(&explored, &arm, 0.8);
        }

        // With equal Thompson samples the unexplored bucket must score higher.
        let score_explored = engine.boosted_score(&explored, &arm, 0.5);
        let score_novel = engine.boosted_score(&novel, &arm, 0.5);
        assert!(score_novel > score_explored);
    }

    #[test]
    fn test_meta_engine_kernel_recording() {
        let mut engine = MetaLearningEngine::new();

        engine.record_kernel("k1", 0.9, 0.3, 0.7, 0);
        engine.record_kernel("k2", 0.7, 0.1, 0.9, 0);
        // k3 is dominated, so it never joins the front.
        engine.record_kernel("k3", 0.5, 0.5, 0.5, 0);

        assert!(engine.pareto.len() <= 2);
    }

    #[test]
    fn test_meta_engine_health_check() {
        let mut engine = MetaLearningEngine::new();
        let bucket = test_bucket("medium", "test");
        let arm = ArmId("a".into());

        for _ in 0..100 {
            engine.record_decision(&bucket, &arm, 0.8);
        }

        let health = engine.health_check();
        assert_eq!(health.curiosity_total_visits, 100);
        assert!(health.is_exploring);
    }

    #[test]
    fn test_meta_engine_plateau_check() {
        let mut engine = MetaLearningEngine::new();

        // Default plateau window is 5, so 10 flat points suffice.
        let flat_points: Vec<CostCurvePoint> =
            (0..10).map(|i| curve_point(i as u64, 0.80)).collect();

        assert_ne!(engine.check_plateau(&flat_points), PlateauAction::Continue);
    }

    #[test]
    fn test_meta_engine_default() {
        let engine = MetaLearningEngine::default();
        assert_eq!(engine.curiosity.exploration_coeff, 1.41);
    }
}