// ruvector_domain_expansion/meta_learning.rs

1//! Meta-Learning Improvements for AGI Learning Architecture
2//!
3//! Five composable enhancements that layer on top of the existing
4//! Thompson Sampling + Population Search + Cost Curve pipeline:
5//!
6//! 1. **RegretTracker**: Measures optimality gap — cumulative difference
7//!    between chosen arms and the best-known arm. You can't improve
8//!    what you don't measure.
9//!
10//! 2. **DecayingBeta**: Beta distribution with exponential forgetting.
11//!    Old evidence decays so the system adapts to non-stationary
12//!    environments instead of calcifying on stale data.
13//!
14//! 3. **PlateauDetector**: Detects when learning has stalled by comparing
15//!    recent accuracy windows. Triggers strategy changes: more exploration,
16//!    cross-domain transfer, or population diversity injection.
17//!
18//! 4. **ParetoFront**: Multi-objective optimization tracking. Instead of
19//!    collapsing accuracy/cost/robustness into one scalar, tracks the full
20//!    Pareto front of non-dominated solutions.
21//!
22//! 5. **CuriosityBonus**: UCB-style exploration bonus for under-visited
23//!    context buckets. Directs exploration toward novel contexts rather
24//!    than relying solely on Thompson Sampling's implicit exploration.
25
26use crate::cost_curve::CostCurvePoint;
27use crate::transfer::{ArmId, BetaParams, ContextBucket};
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30
31// ═══════════════════════════════════════════════════════════════════
32// 1. Regret Tracker
33// ═══════════════════════════════════════════════════════════════════
34
/// Per-bucket regret state: tracks best arm and cumulative regret.
///
/// Updated exclusively through `RegretTracker::record`, which refreshes
/// the per-arm running means and promotes the best arm as data arrives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BucketRegret {
    /// Best known arm mean reward (running mean of the leading arm).
    pub best_mean: f32,
    /// Which arm is currently best.
    pub best_arm: ArmId,
    /// Cumulative regret: Σ(best_reward - chosen_reward).
    pub cumulative_regret: f64,
    /// Total observations in this bucket.
    pub observations: u64,
    /// Per-cycle regret snapshots for trend analysis
    /// (one entry appended every `snapshot_interval` observations).
    pub regret_history: Vec<f64>,
    /// Per-arm running mean for best-arm tracking,
    /// stored as `(reward_sum, observation_count)`.
    arm_means: HashMap<ArmId, (f64, u64)>,
}
51
52impl BucketRegret {
53    fn new() -> Self {
54        Self {
55            best_mean: 0.0,
56            best_arm: ArmId("unknown".into()),
57            cumulative_regret: 0.0,
58            observations: 0,
59            regret_history: Vec::new(),
60            arm_means: HashMap::new(),
61        }
62    }
63}
64
/// Tracks cumulative regret across all context buckets.
///
/// Regret = Σ(best_arm_mean - chosen_arm_reward) over time.
/// Sublinear regret growth (O(√T)) indicates the system is learning.
/// Linear regret means it's not adapting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegretTracker {
    /// Per-bucket regret state, keyed by context bucket.
    buckets: HashMap<ContextBucket, BucketRegret>,
    /// Global cumulative regret across all buckets.
    pub total_regret: f64,
    /// Total observations across all buckets.
    pub total_observations: u64,
    /// Snapshot interval: take regret snapshot every N observations
    /// (clamped to at least 1 by the constructor).
    snapshot_interval: u64,
}
80
81impl RegretTracker {
82    /// Create a new regret tracker.
83    pub fn new(snapshot_interval: u64) -> Self {
84        Self {
85            buckets: HashMap::new(),
86            total_regret: 0.0,
87            total_observations: 0,
88            snapshot_interval: snapshot_interval.max(1),
89        }
90    }
91
92    /// Record a choice and its reward, updating regret.
93    pub fn record(
94        &mut self,
95        bucket: &ContextBucket,
96        arm: &ArmId,
97        reward: f32,
98    ) {
99        // Avoid cloning when entry already exists (hot path optimization).
100        if !self.buckets.contains_key(bucket) {
101            self.buckets.insert(bucket.clone(), BucketRegret::new());
102        }
103        let entry = self.buckets.get_mut(bucket).unwrap();
104
105        // Update arm running mean (avoid clone when arm exists).
106        if !entry.arm_means.contains_key(arm) {
107            entry.arm_means.insert(arm.clone(), (0.0, 0));
108        }
109        let (sum, count) = entry.arm_means.get_mut(arm).unwrap();
110        *sum += reward as f64;
111        *count += 1;
112        let arm_mean = *sum / *count as f64;
113
114        // Update best arm if this arm's mean exceeds current best
115        if arm_mean > entry.best_mean as f64 {
116            entry.best_mean = arm_mean as f32;
117            entry.best_arm = arm.clone();
118        }
119
120        // Instantaneous regret: best_mean - observed_reward
121        let instant_regret = (entry.best_mean as f64 - reward as f64).max(0.0);
122        entry.cumulative_regret += instant_regret;
123        entry.observations += 1;
124        self.total_regret += instant_regret;
125        self.total_observations += 1;
126
127        // Snapshot on interval
128        if entry.observations % self.snapshot_interval == 0 {
129            entry.regret_history.push(entry.cumulative_regret);
130        }
131    }
132
133    /// Regret growth rate for a bucket. Sublinear (< 1.0) means learning.
134    ///
135    /// Computed as: log(regret) / log(observations).
136    /// Perfect learning → 0.5 (O(√T)). No learning → 1.0 (O(T)).
137    pub fn regret_growth_rate(&self, bucket: &ContextBucket) -> Option<f32> {
138        let entry = self.buckets.get(bucket)?;
139        if entry.observations < 10 || entry.cumulative_regret < 1e-10 {
140            return None;
141        }
142        let log_regret = (entry.cumulative_regret).ln();
143        let log_t = (entry.observations as f64).ln();
144        Some((log_regret / log_t) as f32)
145    }
146
147    /// Average regret per observation (lower = better).
148    pub fn average_regret(&self) -> f64 {
149        if self.total_observations == 0 {
150            return 0.0;
151        }
152        self.total_regret / self.total_observations as f64
153    }
154
155    /// Check if learning has converged: regret growth rate < threshold.
156    pub fn has_converged(&self, bucket: &ContextBucket, threshold: f32) -> bool {
157        self.regret_growth_rate(bucket)
158            .map_or(false, |rate| rate < threshold)
159    }
160
161    /// Get regret summary for all buckets.
162    pub fn summary(&self) -> RegretSummary {
163        let bucket_rates: Vec<(ContextBucket, f32)> = self
164            .buckets
165            .keys()
166            .filter_map(|b| self.regret_growth_rate(b).map(|r| (b.clone(), r)))
167            .collect();
168
169        let mean_rate = if bucket_rates.is_empty() {
170            1.0
171        } else {
172            bucket_rates.iter().map(|(_, r)| r).sum::<f32>() / bucket_rates.len() as f32
173        };
174
175        RegretSummary {
176            total_regret: self.total_regret,
177            total_observations: self.total_observations,
178            average_regret: self.average_regret(),
179            mean_growth_rate: mean_rate,
180            bucket_count: self.buckets.len(),
181            converged_buckets: bucket_rates.iter().filter(|(_, r)| *r < 0.7).count(),
182        }
183    }
184}
185
/// Summary of regret across all buckets, as produced by
/// `RegretTracker::summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegretSummary {
    /// Global cumulative regret across all buckets.
    pub total_regret: f64,
    /// Total observations across all buckets.
    pub total_observations: u64,
    /// Total regret divided by total observations (0.0 when empty).
    pub average_regret: f64,
    /// Mean regret growth rate across buckets. < 0.7 = sublinear = learning.
    /// Defaults to 1.0 when no bucket has enough data to estimate a rate.
    pub mean_growth_rate: f32,
    /// Number of buckets tracked.
    pub bucket_count: usize,
    /// Buckets where regret growth is sublinear (rate < 0.7).
    pub converged_buckets: usize,
}
198
199// ═══════════════════════════════════════════════════════════════════
200// 2. Decaying Beta Distribution
201// ═══════════════════════════════════════════════════════════════════
202
/// Beta distribution with exponential forgetting for non-stationary environments.
///
/// On each update, old evidence decays by `decay_factor` before the new
/// observation is added. This gives recent evidence more weight while
/// gradually forgetting stale data.
///
/// Effective window size ≈ 1 / (1 - decay_factor).
/// decay_factor = 0.995 → window ≈ 200 observations.
/// decay_factor = 0.99  → window ≈ 100 observations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecayingBeta {
    /// Success pseudo-count; the uniform prior contributes 1.0.
    pub alpha: f32,
    /// Failure pseudo-count; the uniform prior contributes 1.0.
    pub beta: f32,
    /// Decay factor per observation. 1.0 = no decay (standard Beta).
    /// Clamped to [0.9, 1.0] by the constructors.
    pub decay_factor: f32,
    /// Effective sample size (decayed observation count).
    pub effective_n: f32,
}
221
222impl DecayingBeta {
223    /// Create with uniform prior and specified decay.
224    pub fn new(decay_factor: f32) -> Self {
225        Self {
226            alpha: 1.0,
227            beta: 1.0,
228            decay_factor: decay_factor.clamp(0.9, 1.0),
229            effective_n: 0.0,
230        }
231    }
232
233    /// Create from an existing BetaParams with decay.
234    pub fn from_beta(params: &BetaParams, decay_factor: f32) -> Self {
235        Self {
236            alpha: params.alpha,
237            beta: params.beta,
238            decay_factor: decay_factor.clamp(0.9, 1.0),
239            effective_n: params.alpha + params.beta - 2.0,
240        }
241    }
242
243    /// Update with exponential decay: old evidence shrinks, new evidence adds.
244    pub fn update(&mut self, reward: f32) {
245        // Decay existing evidence toward the prior
246        self.alpha = 1.0 + (self.alpha - 1.0) * self.decay_factor;
247        self.beta = 1.0 + (self.beta - 1.0) * self.decay_factor;
248
249        // Add new observation
250        self.alpha += reward;
251        self.beta += 1.0 - reward;
252
253        // Track effective sample size
254        self.effective_n = self.effective_n * self.decay_factor + 1.0;
255    }
256
257    /// Mean of the distribution.
258    pub fn mean(&self) -> f32 {
259        self.alpha / (self.alpha + self.beta)
260    }
261
262    /// Variance of the distribution.
263    pub fn variance(&self) -> f32 {
264        let total = self.alpha + self.beta;
265        (self.alpha * self.beta) / (total * total * (total + 1.0))
266    }
267
268    /// Convert back to standard BetaParams (snapshot).
269    pub fn to_beta_params(&self) -> BetaParams {
270        BetaParams {
271            alpha: self.alpha,
272            beta: self.beta,
273        }
274    }
275
276    /// Effective window size: how many recent observations dominate.
277    pub fn effective_window(&self) -> f32 {
278        if self.decay_factor >= 1.0 {
279            self.effective_n
280        } else {
281            1.0 / (1.0 - self.decay_factor)
282        }
283    }
284}
285
286// ═══════════════════════════════════════════════════════════════════
287// 3. Plateau Detector
288// ═══════════════════════════════════════════════════════════════════
289
/// Detects when learning has stalled by comparing accuracy windows.
///
/// Compares the mean accuracy of the most recent `window_size` points
/// against the prior window. If improvement is below threshold,
/// learning has plateaued.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlateauDetector {
    /// Number of points per comparison window (clamped to ≥ 3 by `new`).
    pub window_size: usize,
    /// Minimum improvement to not be considered a plateau
    /// (clamped to ≥ 0.001 by `new`).
    pub improvement_threshold: f32,
    /// Number of consecutive plateau detections; reset whenever `check`
    /// observes real movement or has too little data.
    pub consecutive_plateaus: u32,
    /// Total plateaus detected over the detector's lifetime (never reset).
    pub total_plateaus: u32,
}
306
/// What to do when a plateau is detected.
///
/// Returned by `PlateauDetector::check`, which escalates with the number
/// of consecutive plateaus: 1 → `IncreaseExploration`, 2–3 →
/// `TriggerTransfer`, 4–6 → `InjectDiversity`, 7+ → `Reset`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PlateauAction {
    /// Continue learning, no plateau detected.
    Continue,
    /// Mild plateau: increase exploration budget.
    IncreaseExploration,
    /// Moderate plateau: trigger cross-domain transfer.
    TriggerTransfer,
    /// Severe plateau: inject diversity into population.
    InjectDiversity,
    /// Extreme plateau: reset and restart with new strategy.
    Reset,
}
321
322impl PlateauDetector {
323    /// Create a new plateau detector.
324    pub fn new(window_size: usize, improvement_threshold: f32) -> Self {
325        Self {
326            window_size: window_size.max(3),
327            improvement_threshold: improvement_threshold.max(0.001),
328            consecutive_plateaus: 0,
329            total_plateaus: 0,
330        }
331    }
332
333    /// Check if learning has plateaued and recommend an action.
334    pub fn check(&mut self, points: &[CostCurvePoint]) -> PlateauAction {
335        if points.len() < self.window_size * 2 {
336            self.consecutive_plateaus = 0;
337            return PlateauAction::Continue;
338        }
339
340        let n = points.len();
341        let recent = &points[n - self.window_size..];
342        let prior = &points[n - 2 * self.window_size..n - self.window_size];
343
344        let recent_mean = recent.iter().map(|p| p.accuracy).sum::<f32>()
345            / recent.len() as f32;
346        let prior_mean = prior.iter().map(|p| p.accuracy).sum::<f32>()
347            / prior.len() as f32;
348
349        let improvement = recent_mean - prior_mean;
350
351        if improvement.abs() < self.improvement_threshold {
352            self.consecutive_plateaus += 1;
353            self.total_plateaus += 1;
354
355            match self.consecutive_plateaus {
356                1 => PlateauAction::IncreaseExploration,
357                2..=3 => PlateauAction::TriggerTransfer,
358                4..=6 => PlateauAction::InjectDiversity,
359                _ => PlateauAction::Reset,
360            }
361        } else {
362            self.consecutive_plateaus = 0;
363            PlateauAction::Continue
364        }
365    }
366
367    /// Check if cost has plateaued (not just accuracy).
368    pub fn check_cost(&self, points: &[CostCurvePoint]) -> bool {
369        if points.len() < self.window_size * 2 {
370            return false;
371        }
372
373        let n = points.len();
374        let recent = &points[n - self.window_size..];
375        let prior = &points[n - 2 * self.window_size..n - self.window_size];
376
377        let recent_cost = recent.iter().map(|p| p.cost_per_solve).sum::<f32>()
378            / recent.len() as f32;
379        let prior_cost = prior.iter().map(|p| p.cost_per_solve).sum::<f32>()
380            / prior.len() as f32;
381
382        // Cost should be decreasing; if it's not, that's a plateau
383        (prior_cost - recent_cost).abs() < self.improvement_threshold
384    }
385
386    /// Compute learning velocity: rate of accuracy change per cycle.
387    pub fn learning_velocity(&self, points: &[CostCurvePoint]) -> f32 {
388        if points.len() < 2 {
389            return 0.0;
390        }
391        let n = points.len();
392        let window = self.window_size.min(n);
393        let recent = &points[n - window..];
394
395        if recent.len() < 2 {
396            return 0.0;
397        }
398
399        let first = recent.first().unwrap();
400        let last = recent.last().unwrap();
401        let dt = (last.cycle - first.cycle).max(1) as f32;
402
403        (last.accuracy - first.accuracy) / dt
404    }
405}
406
407// ═══════════════════════════════════════════════════════════════════
408// 4. Pareto Front (Multi-Objective Optimization)
409// ═══════════════════════════════════════════════════════════════════
410
/// A point in objective space with its kernel identity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParetoPoint {
    /// Kernel identifier.
    pub kernel_id: String,
    /// Objective values (higher is better for all).
    /// Convention: [accuracy, -cost, robustness] — cost is negated so
    /// every dimension is maximized.
    pub objectives: Vec<f32>,
    /// Generation when this point was added.
    pub generation: u32,
}
422
/// Multi-objective Pareto front tracker.
///
/// Instead of collapsing multiple objectives into one weighted scalar,
/// tracks the full set of non-dominated solutions. A solution is
/// non-dominated if no other solution is better on ALL objectives.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ParetoFront {
    /// Current non-dominated solutions.
    pub front: Vec<ParetoPoint>,
    /// Total points evaluated (every `insert` call, accepted or not).
    pub evaluated: u64,
    /// Number of front updates (when a new point enters the front).
    pub front_updates: u64,
}
437
438impl ParetoFront {
439    pub fn new() -> Self {
440        Self::default()
441    }
442
443    /// Check if point `a` dominates point `b`.
444    ///
445    /// Dominance: a is at least as good as b on all objectives,
446    /// and strictly better on at least one.
447    pub fn dominates(a: &[f32], b: &[f32]) -> bool {
448        if a.len() != b.len() {
449            return false;
450        }
451        let mut at_least_equal = true;
452        let mut strictly_better = false;
453
454        for (ai, bi) in a.iter().zip(b.iter()) {
455            if ai < bi {
456                at_least_equal = false;
457                break;
458            }
459            if ai > bi {
460                strictly_better = true;
461            }
462        }
463
464        at_least_equal && strictly_better
465    }
466
467    /// Insert a point into the front. Returns true if the point is non-dominated.
468    ///
469    /// Removes any existing points that the new point dominates.
470    pub fn insert(&mut self, point: ParetoPoint) -> bool {
471        self.evaluated += 1;
472
473        // Check if any existing point dominates the new one
474        for existing in &self.front {
475            if Self::dominates(&existing.objectives, &point.objectives) {
476                return false; // Dominated, don't add
477            }
478        }
479
480        // Remove points dominated by the new one
481        self.front
482            .retain(|existing| !Self::dominates(&point.objectives, &existing.objectives));
483
484        self.front.push(point);
485        self.front_updates += 1;
486        true
487    }
488
489    /// Hypervolume indicator: volume of objective space dominated by the front.
490    ///
491    /// Uses a reference point (all zeros) as the origin.
492    /// Higher hypervolume = better front coverage.
493    /// Only exact for 2D; uses approximation for higher dimensions.
494    pub fn hypervolume(&self, reference: &[f32]) -> f32 {
495        if self.front.is_empty() || reference.is_empty() {
496            return 0.0;
497        }
498
499        let dim = reference.len();
500        if dim == 2 {
501            self.hypervolume_2d(reference)
502        } else {
503            // Approximate: sum of per-point dominated rectangles (overcounts overlaps)
504            self.front
505                .iter()
506                .map(|p| {
507                    p.objectives
508                        .iter()
509                        .zip(reference.iter())
510                        .map(|(oi, ri)| (oi - ri).max(0.0))
511                        .product::<f32>()
512                })
513                .sum()
514        }
515    }
516
517    /// Exact 2D hypervolume via sweep line.
518    fn hypervolume_2d(&self, reference: &[f32]) -> f32 {
519        if self.front.is_empty() {
520            return 0.0;
521        }
522
523        let mut points: Vec<(f32, f32)> = self
524            .front
525            .iter()
526            .map(|p| {
527                let x = p.objectives.first().copied().unwrap_or(0.0);
528                let y = p.objectives.get(1).copied().unwrap_or(0.0);
529                (x, y)
530            })
531            .collect();
532
533        // Sort by x descending
534        points.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
535
536        let ref_x = reference.first().copied().unwrap_or(0.0);
537        let ref_y = reference.get(1).copied().unwrap_or(0.0);
538
539        let mut volume = 0.0f32;
540        let mut prev_y = ref_y;
541
542        for &(x, y) in &points {
543            if y > prev_y {
544                volume += (x - ref_x) * (y - prev_y);
545                prev_y = y;
546            }
547        }
548
549        volume
550    }
551
552    /// Size of the Pareto front.
553    pub fn len(&self) -> usize {
554        self.front.len()
555    }
556
557    /// Whether the front is empty.
558    pub fn is_empty(&self) -> bool {
559        self.front.is_empty()
560    }
561
562    /// Get the front point that maximizes a specific objective.
563    pub fn best_on(&self, objective_index: usize) -> Option<&ParetoPoint> {
564        self.front
565            .iter()
566            .max_by(|a, b| {
567                let va = a.objectives.get(objective_index).copied().unwrap_or(0.0);
568                let vb = b.objectives.get(objective_index).copied().unwrap_or(0.0);
569                va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
570            })
571    }
572
573    /// Spread: range on each objective dimension. Higher = more diverse front.
574    pub fn spread(&self) -> Vec<f32> {
575        if self.front.is_empty() {
576            return Vec::new();
577        }
578        let dim = self.front[0].objectives.len();
579        (0..dim)
580            .map(|i| {
581                let vals: Vec<f32> = self.front.iter().map(|p| p.objectives[i]).collect();
582                let min = vals.iter().cloned().fold(f32::INFINITY, f32::min);
583                let max = vals.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
584                max - min
585            })
586            .collect()
587    }
588}
589
590// ═══════════════════════════════════════════════════════════════════
591// 5. Curiosity Bonus (UCB-style exploration)
592// ═══════════════════════════════════════════════════════════════════
593
/// UCB-style exploration bonus for under-visited context buckets.
///
/// Adds sqrt(2 * ln(N) / n_i) bonus to arm selection, where N is total
/// observations and n_i is observations for this bucket/arm.
/// This prioritizes under-explored contexts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CuriosityBonus {
    /// Per-bucket, per-arm visit counts.
    visit_counts: HashMap<ContextBucket, HashMap<ArmId, u64>>,
    /// Total visit count across everything.
    pub total_visits: u64,
    /// Exploration coefficient (higher = more curious); floored at 0
    /// by the constructor.
    pub exploration_coeff: f32,
}
608
609impl CuriosityBonus {
610    /// Create with a given exploration coefficient.
611    /// Standard UCB uses sqrt(2) ≈ 1.41. Higher = more exploration.
612    pub fn new(exploration_coeff: f32) -> Self {
613        Self {
614            visit_counts: HashMap::new(),
615            total_visits: 0,
616            exploration_coeff: exploration_coeff.max(0.0),
617        }
618    }
619
620    /// Record a visit to a bucket/arm.
621    pub fn record_visit(&mut self, bucket: &ContextBucket, arm: &ArmId) {
622        // Hot path: avoid cloning when entries already exist.
623        if let Some(arms) = self.visit_counts.get_mut(bucket) {
624            if let Some(count) = arms.get_mut(arm) {
625                *count += 1;
626            } else {
627                arms.insert(arm.clone(), 1);
628            }
629        } else {
630            let mut arms = HashMap::new();
631            arms.insert(arm.clone(), 1u64);
632            self.visit_counts.insert(bucket.clone(), arms);
633        }
634        self.total_visits += 1;
635    }
636
637    /// Compute the exploration bonus for a bucket/arm combination.
638    ///
639    /// Bonus = c * sqrt(ln(N) / n), where:
640    /// - c is the exploration coefficient
641    /// - N is total visits
642    /// - n is visits to this specific bucket/arm
643    pub fn bonus(&self, bucket: &ContextBucket, arm: &ArmId) -> f32 {
644        if self.total_visits < 2 {
645            return self.exploration_coeff; // Maximum bonus when no data
646        }
647
648        let arm_visits = self
649            .visit_counts
650            .get(bucket)
651            .and_then(|arms| arms.get(arm))
652            .copied()
653            .unwrap_or(0);
654
655        if arm_visits == 0 {
656            return self.exploration_coeff * 2.0; // Never-visited bonus
657        }
658
659        let log_n = (self.total_visits as f32).ln();
660        self.exploration_coeff * (log_n / arm_visits as f32).sqrt()
661    }
662
663    /// Find the most under-explored bucket (lowest total visits).
664    pub fn most_curious_bucket(&self) -> Option<&ContextBucket> {
665        // Find buckets with fewest total visits
666        let mut min_visits = u64::MAX;
667        let mut most_curious = None;
668
669        for (bucket, arms) in &self.visit_counts {
670            let total: u64 = arms.values().sum();
671            if total < min_visits {
672                min_visits = total;
673                most_curious = Some(bucket);
674            }
675        }
676
677        most_curious
678    }
679
680    /// Novelty score for a bucket: inverse of visit density.
681    /// Higher = more novel / less explored.
682    pub fn novelty_score(&self, bucket: &ContextBucket) -> f32 {
683        if self.total_visits == 0 {
684            return 1.0;
685        }
686
687        let bucket_visits: u64 = self
688            .visit_counts
689            .get(bucket)
690            .map(|arms| arms.values().sum())
691            .unwrap_or(0);
692
693        if bucket_visits == 0 {
694            return 1.0;
695        }
696
697        1.0 - (bucket_visits as f32 / self.total_visits as f32)
698    }
699}
700
701// ═══════════════════════════════════════════════════════════════════
702// Integrated Meta-Learning Engine
703// ═══════════════════════════════════════════════════════════════════
704
/// Unified meta-learning engine that composes all five improvements.
///
/// Drop-in enhancement for the existing DomainExpansionEngine.
/// Call `record_decision` after each arm selection and `check_plateau`
/// periodically to get adaptive strategy recommendations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningEngine {
    /// Cumulative regret tracking across context buckets.
    pub regret: RegretTracker,
    /// Accuracy/cost stagnation detector.
    pub plateau: PlateauDetector,
    /// Multi-objective front of kernel performances.
    pub pareto: ParetoFront,
    /// UCB-style exploration bonus state.
    pub curiosity: CuriosityBonus,
    /// Per-bucket decaying beta distributions (optional overlay),
    /// keyed by (bucket, arm).
    pub decaying_betas: HashMap<(ContextBucket, ArmId), DecayingBeta>,
    /// Decay factor used when a new decaying beta is created.
    decay_factor: f32,
}
721
722impl MetaLearningEngine {
723    /// Create with standard parameters.
724    pub fn new() -> Self {
725        Self {
726            regret: RegretTracker::new(50),
727            plateau: PlateauDetector::new(5, 0.005),
728            pareto: ParetoFront::new(),
729            curiosity: CuriosityBonus::new(1.41),
730            decaying_betas: HashMap::new(),
731            decay_factor: 0.995,
732        }
733    }
734
735    /// Create with custom parameters.
736    pub fn with_config(
737        regret_snapshot_interval: u64,
738        plateau_window: usize,
739        plateau_threshold: f32,
740        exploration_coeff: f32,
741        decay_factor: f32,
742    ) -> Self {
743        Self {
744            regret: RegretTracker::new(regret_snapshot_interval),
745            plateau: PlateauDetector::new(plateau_window, plateau_threshold),
746            pareto: ParetoFront::new(),
747            curiosity: CuriosityBonus::new(exploration_coeff),
748            decaying_betas: HashMap::new(),
749            decay_factor,
750        }
751    }
752
753    /// Record a decision outcome. Call after every arm selection.
754    pub fn record_decision(
755        &mut self,
756        bucket: &ContextBucket,
757        arm: &ArmId,
758        reward: f32,
759    ) {
760        // 1. Track regret
761        self.regret.record(bucket, arm, reward);
762
763        // 2. Update curiosity counts
764        self.curiosity.record_visit(bucket, arm);
765
766        // 3. Update decaying beta for this bucket/arm.
767        //    Avoid tuple clone on hot path when entry exists.
768        let key = (bucket.clone(), arm.clone());
769        if let Some(db) = self.decaying_betas.get_mut(&key) {
770            db.update(reward);
771        } else {
772            let mut db = DecayingBeta::new(self.decay_factor);
773            db.update(reward);
774            self.decaying_betas.insert(key, db);
775        }
776    }
777
778    /// Record a population kernel's multi-objective performance.
779    pub fn record_kernel(
780        &mut self,
781        kernel_id: &str,
782        accuracy: f32,
783        cost: f32,
784        robustness: f32,
785        generation: u32,
786    ) {
787        let point = ParetoPoint {
788            kernel_id: kernel_id.to_string(),
789            // Convention: higher is better, so negate cost
790            objectives: vec![accuracy, -cost, robustness],
791            generation,
792        };
793        self.pareto.insert(point);
794    }
795
796    /// Check the cost curve for plateau and recommend action.
797    pub fn check_plateau(&mut self, points: &[CostCurvePoint]) -> PlateauAction {
798        self.plateau.check(points)
799    }
800
801    /// Get the curiosity-boosted score for an arm.
802    ///
803    /// Combines the Thompson Sampling estimate with an exploration bonus.
804    pub fn boosted_score(
805        &self,
806        bucket: &ContextBucket,
807        arm: &ArmId,
808        thompson_sample: f32,
809    ) -> f32 {
810        let bonus = self.curiosity.bonus(bucket, arm);
811        thompson_sample + bonus
812    }
813
814    /// Get the decaying beta mean for a bucket/arm (if tracked).
815    pub fn decaying_mean(
816        &self,
817        bucket: &ContextBucket,
818        arm: &ArmId,
819    ) -> Option<f32> {
820        let key = (bucket.clone(), arm.clone());
821        self.decaying_betas.get(&key).map(|db| db.mean())
822    }
823
824    /// Comprehensive health check of the learning system.
825    pub fn health_check(&self) -> MetaLearningHealth {
826        let regret_summary = self.regret.summary();
827        let pareto_size = self.pareto.len();
828
829        let is_learning = regret_summary.mean_growth_rate < 0.8;
830        let is_diverse = pareto_size >= 3;
831        let is_exploring = self.curiosity.total_visits > 0;
832
833        MetaLearningHealth {
834            regret: regret_summary,
835            pareto_size,
836            pareto_hypervolume: self.pareto.hypervolume(&[0.0, -1.0, 0.0]),
837            consecutive_plateaus: self.plateau.consecutive_plateaus,
838            total_plateaus: self.plateau.total_plateaus,
839            curiosity_total_visits: self.curiosity.total_visits,
840            decaying_beta_count: self.decaying_betas.len(),
841            is_learning,
842            is_diverse,
843            is_exploring,
844        }
845    }
846}
847
impl Default for MetaLearningEngine {
    /// Delegates to `MetaLearningEngine::new` (standard parameters).
    fn default() -> Self {
        Self::new()
    }
}
853
/// Health summary of the meta-learning system, produced by
/// `MetaLearningEngine::health_check`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningHealth {
    /// Regret totals and growth rates across all buckets.
    pub regret: RegretSummary,
    /// Number of non-dominated points currently on the Pareto front.
    pub pareto_size: usize,
    /// Hypervolume measured from the reference corner [0.0, -1.0, 0.0].
    pub pareto_hypervolume: f32,
    /// Consecutive plateau detections (resets on real movement).
    pub consecutive_plateaus: u32,
    /// Lifetime total of plateau detections.
    pub total_plateaus: u32,
    /// Total visits recorded by the curiosity tracker.
    pub curiosity_total_visits: u64,
    /// Number of (bucket, arm) decaying beta distributions tracked.
    pub decaying_beta_count: usize,
    /// True if regret growth is sublinear (system is learning).
    pub is_learning: bool,
    /// True if Pareto front has diverse solutions (≥ 3 points).
    pub is_diverse: bool,
    /// True if curiosity is actively exploring (any visit recorded).
    pub is_exploring: bool,
}
871
872// ═══════════════════════════════════════════════════════════════════
873// Tests
874// ═══════════════════════════════════════════════════════════════════
875
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ContextBucket` from a difficulty tier and category.
    fn test_bucket(tier: &str, cat: &str) -> ContextBucket {
        ContextBucket {
            difficulty_tier: tier.into(),
            category: cat.into(),
        }
    }

    // -- RegretTracker tests --

    #[test]
    fn test_regret_tracker_empty() {
        let tracker = RegretTracker::new(10);
        assert_eq!(tracker.total_regret, 0.0);
        assert_eq!(tracker.average_regret(), 0.0);
    }

    #[test]
    fn test_regret_tracker_optimal_arm() {
        let mut tracker = RegretTracker::new(10);
        let bucket = test_bucket("easy", "test");
        let arm = ArmId("best".into());

        // Always picking the best arm → zero regret
        for _ in 0..100 {
            tracker.record(&bucket, &arm, 0.9);
        }

        assert_eq!(tracker.total_observations, 100);
        // All same arm, so regret is 0
        assert!(tracker.total_regret < 1e-6);
    }

    #[test]
    fn test_regret_tracker_suboptimal_arm() {
        let mut tracker = RegretTracker::new(10);
        let bucket = test_bucket("medium", "test");
        let good = ArmId("good".into());
        let bad = ArmId("bad".into());

        // Establish good arm's mean
        for _ in 0..50 {
            tracker.record(&bucket, &good, 0.9);
        }

        // Now pick the bad arm repeatedly → regret accumulates
        for _ in 0..50 {
            tracker.record(&bucket, &bad, 0.3);
        }

        assert!(tracker.total_regret > 0.0);
        assert!(tracker.average_regret() > 0.0);
    }

    #[test]
    fn test_regret_growth_rate() {
        let mut tracker = RegretTracker::new(5);
        let bucket = test_bucket("hard", "test");
        let arm_a = ArmId("a".into());
        let arm_b = ArmId("b".into());

        for _ in 0..50 {
            tracker.record(&bucket, &arm_a, 0.8);
        }
        for _ in 0..50 {
            tracker.record(&bucket, &arm_b, 0.4);
        }

        let rate = tracker.regret_growth_rate(&bucket);
        assert!(rate.is_some());
        // Rate should be defined (we have enough observations)
    }

    #[test]
    fn test_regret_summary() {
        let mut tracker = RegretTracker::new(10);
        let bucket = test_bucket("easy", "algo");
        let arm = ArmId("test".into());

        for _ in 0..20 {
            tracker.record(&bucket, &arm, 0.7);
        }

        let summary = tracker.summary();
        assert_eq!(summary.total_observations, 20);
        assert_eq!(summary.bucket_count, 1);
    }

    // -- DecayingBeta tests --

    #[test]
    fn test_decaying_beta_initial() {
        let db = DecayingBeta::new(0.995);
        assert!((db.mean() - 0.5).abs() < 1e-6); // Uniform prior
        assert_eq!(db.effective_n, 0.0);
    }

    #[test]
    fn test_decaying_beta_update() {
        let mut db = DecayingBeta::new(0.995);

        for _ in 0..100 {
            db.update(0.9); // Mostly successes
        }

        assert!(db.mean() > 0.7); // Should reflect high success rate
        assert!(db.effective_n > 50.0); // Decayed but substantial
    }

    #[test]
    fn test_decaying_beta_adapts() {
        let mut db = DecayingBeta::new(0.99); // Faster decay

        // First: many successes
        for _ in 0..100 {
            db.update(0.95);
        }
        let mean_after_good = db.mean();
        assert!(mean_after_good > 0.8);

        // Then: many failures (environment changed)
        for _ in 0..100 {
            db.update(0.1);
        }
        let mean_after_bad = db.mean();

        // With decay, it should adapt toward the new distribution
        assert!(mean_after_bad < mean_after_good);
        assert!(mean_after_bad < 0.5); // Should reflect recent failures
    }

    #[test]
    fn test_decaying_beta_window() {
        let db = DecayingBeta::new(0.99);
        let window = db.effective_window();
        assert!((window - 100.0).abs() < 1.0); // 1/(1-0.99) = 100

        let db2 = DecayingBeta::new(0.995);
        let window2 = db2.effective_window();
        assert!((window2 - 200.0).abs() < 1.0); // 1/(1-0.995) = 200
    }

    #[test]
    fn test_decaying_to_standard() {
        let mut db = DecayingBeta::new(0.995);
        for _ in 0..10 {
            db.update(0.8);
        }
        // Conversion to standard BetaParams must preserve the mean.
        let params = db.to_beta_params();
        assert!(params.alpha > 1.0);
        assert!(params.beta > 1.0);
        assert!((params.mean() - db.mean()).abs() < 1e-6);
    }

    // -- PlateauDetector tests --

    #[test]
    fn test_plateau_no_data() {
        let mut detector = PlateauDetector::new(3, 0.01);
        let action = detector.check(&[]);
        assert_eq!(action, PlateauAction::Continue);
    }

    #[test]
    fn test_plateau_not_enough_data() {
        let mut detector = PlateauDetector::new(3, 0.01);
        // Fewer points than two full comparison windows → no verdict yet.
        let points: Vec<CostCurvePoint> = (0..4)
            .map(|i| CostCurvePoint {
                cycle: i as u64,
                accuracy: 0.5 + i as f32 * 0.1,
                cost_per_solve: 0.1,
                robustness: 0.8,
                policy_violations: 0,
                timestamp: i as f64,
            })
            .collect();

        let action = detector.check(&points);
        assert_eq!(action, PlateauAction::Continue);
    }

    #[test]
    fn test_plateau_detected() {
        let mut detector = PlateauDetector::new(3, 0.01);

        // Flat accuracy → plateau
        let points: Vec<CostCurvePoint> = (0..6)
            .map(|i| CostCurvePoint {
                cycle: i as u64,
                accuracy: 0.80 + (i as f32 * 0.001), // Nearly flat
                cost_per_solve: 0.1,
                robustness: 0.8,
                policy_violations: 0,
                timestamp: i as f64,
            })
            .collect();

        let action = detector.check(&points);
        assert_ne!(action, PlateauAction::Continue);
    }

    #[test]
    fn test_plateau_improving() {
        let mut detector = PlateauDetector::new(3, 0.01);

        // Clear improvement → no plateau
        let points: Vec<CostCurvePoint> = (0..6)
            .map(|i| CostCurvePoint {
                cycle: i as u64,
                accuracy: 0.50 + i as f32 * 0.08, // Strong improvement
                cost_per_solve: 0.1,
                robustness: 0.8,
                policy_violations: 0,
                timestamp: i as f64,
            })
            .collect();

        let action = detector.check(&points);
        assert_eq!(action, PlateauAction::Continue);
    }

    #[test]
    fn test_plateau_escalation() {
        let mut detector = PlateauDetector::new(3, 0.01);

        let flat_points: Vec<CostCurvePoint> = (0..6)
            .map(|i| CostCurvePoint {
                cycle: i as u64,
                accuracy: 0.80,
                cost_per_solve: 0.1,
                robustness: 0.8,
                policy_violations: 0,
                timestamp: i as f64,
            })
            .collect();

        // Repeated plateaus escalate the response:
        // exploration → transfer (twice) → diversity injection.
        assert_eq!(detector.check(&flat_points), PlateauAction::IncreaseExploration);
        assert_eq!(detector.check(&flat_points), PlateauAction::TriggerTransfer);
        assert_eq!(detector.check(&flat_points), PlateauAction::TriggerTransfer);
        assert_eq!(detector.check(&flat_points), PlateauAction::InjectDiversity);
    }

    #[test]
    fn test_learning_velocity() {
        let detector = PlateauDetector::new(3, 0.01);

        let points: Vec<CostCurvePoint> = (0..6)
            .map(|i| CostCurvePoint {
                cycle: i as u64,
                accuracy: 0.50 + i as f32 * 0.1,
                cost_per_solve: 0.1,
                robustness: 0.8,
                policy_violations: 0,
                timestamp: i as f64,
            })
            .collect();

        let velocity = detector.learning_velocity(&points);
        assert!(velocity > 0.0); // Should be positive (improving)
    }

    // -- ParetoFront tests --

    #[test]
    fn test_pareto_dominates() {
        // All objectives are maximized; cost is stored negated.
        assert!(ParetoFront::dominates(&[0.9, -0.1, 0.8], &[0.8, -0.2, 0.7]));
        assert!(!ParetoFront::dominates(&[0.9, -0.3, 0.8], &[0.8, -0.1, 0.7]));
        assert!(!ParetoFront::dominates(&[0.9, -0.1, 0.8], &[0.9, -0.1, 0.8])); // Equal
    }

    #[test]
    fn test_pareto_insert_non_dominated() {
        let mut front = ParetoFront::new();

        // Two non-dominated points (tradeoff: accuracy vs cost)
        assert!(front.insert(ParetoPoint {
            kernel_id: "a".into(),
            objectives: vec![0.9, -0.3, 0.7],
            generation: 0,
        }));
        assert!(front.insert(ParetoPoint {
            kernel_id: "b".into(),
            objectives: vec![0.7, -0.1, 0.9],
            generation: 0,
        }));

        assert_eq!(front.len(), 2);
    }

    #[test]
    fn test_pareto_insert_dominated() {
        let mut front = ParetoFront::new();

        front.insert(ParetoPoint {
            kernel_id: "good".into(),
            objectives: vec![0.9, -0.1, 0.9],
            generation: 0,
        });

        // This is dominated by "good" on all objectives
        let added = front.insert(ParetoPoint {
            kernel_id: "bad".into(),
            objectives: vec![0.5, -0.5, 0.5],
            generation: 0,
        });

        assert!(!added);
        assert_eq!(front.len(), 1);
    }

    #[test]
    fn test_pareto_removes_dominated() {
        let mut front = ParetoFront::new();

        front.insert(ParetoPoint {
            kernel_id: "old".into(),
            objectives: vec![0.5, -0.3, 0.5],
            generation: 0,
        });

        // New point dominates old
        front.insert(ParetoPoint {
            kernel_id: "new".into(),
            objectives: vec![0.9, -0.1, 0.9],
            generation: 1,
        });

        assert_eq!(front.len(), 1);
        assert_eq!(front.front[0].kernel_id, "new");
    }

    #[test]
    fn test_pareto_best_on_objective() {
        let mut front = ParetoFront::new();

        front.insert(ParetoPoint {
            kernel_id: "accurate".into(),
            objectives: vec![0.95, -0.5, 0.6],
            generation: 0,
        });
        front.insert(ParetoPoint {
            kernel_id: "cheap".into(),
            objectives: vec![0.7, -0.05, 0.7],
            generation: 0,
        });
        front.insert(ParetoPoint {
            kernel_id: "robust".into(),
            objectives: vec![0.8, -0.3, 0.95],
            generation: 0,
        });

        assert_eq!(front.best_on(0).unwrap().kernel_id, "accurate");
        assert_eq!(front.best_on(1).unwrap().kernel_id, "cheap"); // -0.05 > -0.5
        assert_eq!(front.best_on(2).unwrap().kernel_id, "robust");
    }

    #[test]
    fn test_pareto_spread() {
        let mut front = ParetoFront::new();

        // Non-dominated tradeoff: a is better on obj0, b is better on obj1.
        front.insert(ParetoPoint {
            kernel_id: "a".into(),
            objectives: vec![0.9, -0.5],
            generation: 0,
        });
        front.insert(ParetoPoint {
            kernel_id: "b".into(),
            objectives: vec![0.5, -0.1],
            generation: 0,
        });

        assert_eq!(front.len(), 2); // Both should survive (non-dominated)
        // Spread is the per-objective max − min across the front.
        let spread = front.spread();
        assert_eq!(spread.len(), 2);
        assert!((spread[0] - 0.4).abs() < 1e-4); // 0.9 - 0.5
        assert!((spread[1] - 0.4).abs() < 1e-4); // -0.1 - (-0.5)
    }

    #[test]
    fn test_pareto_hypervolume_2d() {
        let mut front = ParetoFront::new();

        front.insert(ParetoPoint {
            kernel_id: "a".into(),
            objectives: vec![1.0, 1.0],
            generation: 0,
        });

        let hv = front.hypervolume(&[0.0, 0.0]);
        assert!((hv - 1.0).abs() < 1e-4); // 1x1 rectangle
    }

    // -- CuriosityBonus tests --

    #[test]
    fn test_curiosity_bonus_unvisited() {
        let curiosity = CuriosityBonus::new(1.41);
        let bucket = test_bucket("hard", "novel");
        let arm = ArmId("new".into());

        let bonus = curiosity.bonus(&bucket, &arm);
        assert!(bonus > 0.0); // Should have high bonus for unvisited
    }

    #[test]
    fn test_curiosity_bonus_decays_with_visits() {
        let mut curiosity = CuriosityBonus::new(1.41);
        let bucket = test_bucket("easy", "test");
        let arm = ArmId("a".into());

        let bonus_before = curiosity.bonus(&bucket, &arm);

        for _ in 0..50 {
            curiosity.record_visit(&bucket, &arm);
        }

        let bonus_after = curiosity.bonus(&bucket, &arm);
        assert!(bonus_after < bonus_before); // Bonus should decrease
    }

    #[test]
    fn test_curiosity_novelty_score() {
        let mut curiosity = CuriosityBonus::new(1.41);
        let explored = test_bucket("easy", "common");
        let novel = test_bucket("hard", "rare");
        let arm = ArmId("a".into());

        // Heavily visit one bucket, touch the other once.
        for _ in 0..100 {
            curiosity.record_visit(&explored, &arm);
        }
        curiosity.record_visit(&novel, &arm);

        let explored_novelty = curiosity.novelty_score(&explored);
        let novel_novelty = curiosity.novelty_score(&novel);

        assert!(novel_novelty > explored_novelty);
    }

    // -- MetaLearningEngine integration tests --

    #[test]
    fn test_meta_engine_creation() {
        let engine = MetaLearningEngine::new();
        assert_eq!(engine.regret.total_observations, 0);
        assert!(engine.pareto.is_empty());
        assert_eq!(engine.curiosity.total_visits, 0);
    }

    #[test]
    fn test_meta_engine_record_decision() {
        let mut engine = MetaLearningEngine::new();
        let bucket = test_bucket("medium", "algo");
        let arm = ArmId("greedy".into());

        for _ in 0..50 {
            engine.record_decision(&bucket, &arm, 0.85);
        }

        // One record_decision should feed regret, curiosity, and decay.
        assert_eq!(engine.regret.total_observations, 50);
        assert_eq!(engine.curiosity.total_visits, 50);
        assert!(engine.decaying_mean(&bucket, &arm).unwrap() > 0.7);
    }

    #[test]
    fn test_meta_engine_boosted_score() {
        let mut engine = MetaLearningEngine::new();
        let explored = test_bucket("easy", "common");
        let novel = test_bucket("hard", "rare");
        let arm = ArmId("a".into());

        // Explore one bucket heavily
        for _ in 0..100 {
            engine.record_decision(&explored, &arm, 0.8);
        }

        let score_explored = engine.boosted_score(&explored, &arm, 0.5);
        let score_novel = engine.boosted_score(&novel, &arm, 0.5);

        // Novel bucket should get higher boosted score
        assert!(score_novel > score_explored);
    }

    #[test]
    fn test_meta_engine_kernel_recording() {
        let mut engine = MetaLearningEngine::new();

        engine.record_kernel("k1", 0.9, 0.3, 0.7, 0);
        engine.record_kernel("k2", 0.7, 0.1, 0.9, 0);
        engine.record_kernel("k3", 0.5, 0.5, 0.5, 0); // Dominated by k1

        // k1 and k2 are non-dominated; k3 is dominated
        assert!(engine.pareto.len() <= 2);
    }

    #[test]
    fn test_meta_engine_health_check() {
        let mut engine = MetaLearningEngine::new();
        let bucket = test_bucket("medium", "test");
        let arm = ArmId("a".into());

        for _ in 0..100 {
            engine.record_decision(&bucket, &arm, 0.8);
        }

        let health = engine.health_check();
        assert_eq!(health.curiosity_total_visits, 100);
        assert!(health.is_exploring);
    }

    #[test]
    fn test_meta_engine_plateau_check() {
        let mut engine = MetaLearningEngine::new();

        let flat_points: Vec<CostCurvePoint> = (0..10)
            .map(|i| CostCurvePoint {
                cycle: i as u64,
                accuracy: 0.80,
                cost_per_solve: 0.1,
                robustness: 0.8,
                policy_violations: 0,
                timestamp: i as f64,
            })
            .collect();

        let action = engine.check_plateau(&flat_points);
        assert_ne!(action, PlateauAction::Continue);
    }

    #[test]
    fn test_meta_engine_default() {
        let engine = MetaLearningEngine::default();
        assert_eq!(engine.curiosity.exploration_coeff, 1.41);
    }
}
1413}