Skip to main content

ruvector_domain_expansion/
policy_kernel.rs

//! PolicyKernel: Population-Based Policy Search
//!
//! Run a small population of policy variants in parallel.
//! Each variant tweaks a small set of knobs:
//! - skip mode policy
//! - prepass mode
//! - speculation trigger thresholds
//! - budget allocation
//!
//! Selection: keep the top performers on holdout tasks, mutate their knobs, repeat.
//! Only merge deltas that pass replay verification.
12
13use crate::domain::DomainId;
14use rand::Rng;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17
18/// Configuration knobs that a PolicyKernel can tune.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct PolicyKnobs {
21    /// Whether to skip low-value operations.
22    pub skip_mode: bool,
23    /// Run a cheaper prepass before full execution.
24    pub prepass_enabled: bool,
25    /// Threshold for triggering speculative dual-path [0.0, 1.0].
26    pub speculation_threshold: f32,
27    /// Budget fraction allocated to exploration vs exploitation [0.0, 1.0].
28    pub exploration_budget: f32,
29    /// Maximum retries on failure.
30    pub max_retries: u32,
31    /// Batch size for parallel evaluation.
32    pub batch_size: usize,
33    /// Cost decay factor for EMA.
34    pub cost_decay: f32,
35    /// Minimum confidence to skip uncertainty check.
36    pub confidence_floor: f32,
37}
38
39impl PolicyKnobs {
40    /// Sensible defaults.
41    pub fn default_knobs() -> Self {
42        Self {
43            skip_mode: false,
44            prepass_enabled: true,
45            speculation_threshold: 0.15,
46            exploration_budget: 0.2,
47            max_retries: 2,
48            batch_size: 8,
49            cost_decay: 0.9,
50            confidence_floor: 0.7,
51        }
52    }
53
54    /// Mutate knobs with small random perturbations.
55    pub fn mutate(&self, rng: &mut impl Rng, mutation_rate: f32) -> Self {
56        let mut knobs = self.clone();
57
58        if rng.gen::<f32>() < mutation_rate {
59            knobs.skip_mode = !knobs.skip_mode;
60        }
61        if rng.gen::<f32>() < mutation_rate {
62            knobs.prepass_enabled = !knobs.prepass_enabled;
63        }
64        if rng.gen::<f32>() < mutation_rate {
65            let delta: f32 = rng.gen_range(-0.1..0.1);
66            knobs.speculation_threshold = (knobs.speculation_threshold + delta).clamp(0.01, 0.5);
67        }
68        if rng.gen::<f32>() < mutation_rate {
69            let delta: f32 = rng.gen_range(-0.1..0.1);
70            knobs.exploration_budget = (knobs.exploration_budget + delta).clamp(0.01, 0.5);
71        }
72        if rng.gen::<f32>() < mutation_rate {
73            knobs.max_retries = rng.gen_range(0..5);
74        }
75        if rng.gen::<f32>() < mutation_rate {
76            knobs.batch_size = rng.gen_range(1..32);
77        }
78        if rng.gen::<f32>() < mutation_rate {
79            let delta: f32 = rng.gen_range(-0.05..0.05);
80            knobs.cost_decay = (knobs.cost_decay + delta).clamp(0.5, 0.99);
81        }
82        if rng.gen::<f32>() < mutation_rate {
83            let delta: f32 = rng.gen_range(-0.1..0.1);
84            knobs.confidence_floor = (knobs.confidence_floor + delta).clamp(0.3, 0.95);
85        }
86
87        knobs
88    }
89
90    /// Crossover two parent knobs to produce a child.
91    pub fn crossover(&self, other: &PolicyKnobs, rng: &mut impl Rng) -> Self {
92        Self {
93            skip_mode: if rng.gen() { self.skip_mode } else { other.skip_mode },
94            prepass_enabled: if rng.gen() {
95                self.prepass_enabled
96            } else {
97                other.prepass_enabled
98            },
99            speculation_threshold: if rng.gen() {
100                self.speculation_threshold
101            } else {
102                other.speculation_threshold
103            },
104            exploration_budget: if rng.gen() {
105                self.exploration_budget
106            } else {
107                other.exploration_budget
108            },
109            max_retries: if rng.gen() {
110                self.max_retries
111            } else {
112                other.max_retries
113            },
114            batch_size: if rng.gen() {
115                self.batch_size
116            } else {
117                other.batch_size
118            },
119            cost_decay: if rng.gen() {
120                self.cost_decay
121            } else {
122                other.cost_decay
123            },
124            confidence_floor: if rng.gen() {
125                self.confidence_floor
126            } else {
127                other.confidence_floor
128            },
129        }
130    }
131}
132
133/// A PolicyKernel is a versioned policy configuration with performance history.
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct PolicyKernel {
136    /// Unique identifier.
137    pub id: String,
138    /// Configuration knobs.
139    pub knobs: PolicyKnobs,
140    /// Performance on holdout tasks (domain_id -> score).
141    pub holdout_scores: HashMap<DomainId, f32>,
142    /// Total cost incurred.
143    pub total_cost: f32,
144    /// Number of evaluation cycles.
145    pub cycles: u64,
146    /// Generation (0 = initial, increments on mutation).
147    pub generation: u32,
148    /// Parent kernel ID (for lineage tracking).
149    pub parent_id: Option<String>,
150    /// Whether this kernel has been verified via replay.
151    pub replay_verified: bool,
152}
153
154impl PolicyKernel {
155    /// Create a new kernel with default knobs.
156    pub fn new(id: String) -> Self {
157        Self {
158            id,
159            knobs: PolicyKnobs::default_knobs(),
160            holdout_scores: HashMap::new(),
161            total_cost: 0.0,
162            cycles: 0,
163            generation: 0,
164            parent_id: None,
165            replay_verified: false,
166        }
167    }
168
169    /// Create a mutated child kernel.
170    pub fn mutate(&self, child_id: String, rng: &mut impl Rng) -> Self {
171        Self {
172            id: child_id,
173            knobs: self.knobs.mutate(rng, 0.3),
174            holdout_scores: HashMap::new(),
175            total_cost: 0.0,
176            cycles: 0,
177            generation: self.generation + 1,
178            parent_id: Some(self.id.clone()),
179            replay_verified: false,
180        }
181    }
182
183    /// Record a holdout score for a domain.
184    pub fn record_score(&mut self, domain_id: DomainId, score: f32, cost: f32) {
185        self.holdout_scores.insert(domain_id, score);
186        self.total_cost += cost;
187        self.cycles += 1;
188    }
189
190    /// Fitness: average holdout score across all evaluated domains.
191    pub fn fitness(&self) -> f32 {
192        if self.holdout_scores.is_empty() {
193            return 0.0;
194        }
195        let total: f32 = self.holdout_scores.values().sum();
196        total / self.holdout_scores.len() as f32
197    }
198
199    /// Cost-adjusted fitness: penalizes expensive kernels.
200    pub fn cost_adjusted_fitness(&self) -> f32 {
201        let raw = self.fitness();
202        let cost_penalty = (self.total_cost / self.cycles.max(1) as f32).min(1.0);
203        raw * (1.0 - cost_penalty * 0.3) // 30% weight on cost
204    }
205}
206
207/// Population-based policy search engine.
208#[derive(Clone)]
209pub struct PopulationSearch {
210    /// Current population of kernels.
211    population: Vec<PolicyKernel>,
212    /// Population size.
213    pop_size: usize,
214    /// Best kernel seen so far.
215    best_kernel: Option<PolicyKernel>,
216    /// Generation counter.
217    generation: u32,
218}
219
220impl PopulationSearch {
221    /// Create a new population search with initial random population.
222    pub fn new(pop_size: usize) -> Self {
223        let mut rng = rand::thread_rng();
224        let population: Vec<PolicyKernel> = (0..pop_size)
225            .map(|i| {
226                let mut kernel = PolicyKernel::new(format!("kernel_g0_{}", i));
227                // Random initial knobs
228                kernel.knobs = PolicyKnobs::default_knobs().mutate(&mut rng, 0.8);
229                kernel
230            })
231            .collect();
232
233        Self {
234            population,
235            pop_size,
236            best_kernel: None,
237            generation: 0,
238        }
239    }
240
241    /// Get current population for evaluation.
242    pub fn population(&self) -> &[PolicyKernel] {
243        &self.population
244    }
245
246    /// Get mutable reference to a kernel by index.
247    pub fn kernel_mut(&mut self, index: usize) -> Option<&mut PolicyKernel> {
248        self.population.get_mut(index)
249    }
250
251    /// Evolve to next generation: select top performers, mutate, fill population.
252    pub fn evolve(&mut self) {
253        let mut rng = rand::thread_rng();
254        self.generation += 1;
255
256        // Sort by cost-adjusted fitness (descending)
257        self.population
258            .sort_by(|a, b| {
259                b.cost_adjusted_fitness()
260                    .partial_cmp(&a.cost_adjusted_fitness())
261                    .unwrap_or(std::cmp::Ordering::Equal)
262            });
263
264        // Track best
265        if let Some(best) = self.population.first() {
266            if self
267                .best_kernel
268                .as_ref()
269                .map_or(true, |b| best.fitness() > b.fitness())
270            {
271                self.best_kernel = Some(best.clone());
272            }
273        }
274
275        // Elite selection: keep top 25%
276        let elite_count = (self.pop_size / 4).max(1);
277        let elites: Vec<PolicyKernel> = self.population[..elite_count].to_vec();
278
279        // Build next generation
280        let mut next_gen = Vec::with_capacity(self.pop_size);
281
282        // Keep elites
283        for elite in &elites {
284            let mut kept = elite.clone();
285            kept.id = format!("kernel_g{}_{}", self.generation, next_gen.len());
286            kept.holdout_scores.clear();
287            kept.total_cost = 0.0;
288            kept.cycles = 0;
289            next_gen.push(kept);
290        }
291
292        // Fill rest with mutations and crossovers
293        while next_gen.len() < self.pop_size {
294            let parent_idx = rng.gen_range(0..elites.len());
295            let child_id = format!("kernel_g{}_{}", self.generation, next_gen.len());
296
297            let child = if rng.gen::<f32>() < 0.3 && elites.len() > 1 {
298                // Crossover
299                let other_idx = (parent_idx + 1 + rng.gen_range(0..elites.len() - 1)) % elites.len();
300                let mut child = PolicyKernel::new(child_id);
301                child.knobs = elites[parent_idx]
302                    .knobs
303                    .crossover(&elites[other_idx].knobs, &mut rng);
304                child.generation = self.generation;
305                child.parent_id = Some(elites[parent_idx].id.clone());
306                child
307            } else {
308                // Mutation
309                elites[parent_idx].mutate(child_id, &mut rng)
310            };
311
312            next_gen.push(child);
313        }
314
315        self.population = next_gen;
316    }
317
318    /// Get the best kernel found so far.
319    pub fn best(&self) -> Option<&PolicyKernel> {
320        self.best_kernel.as_ref()
321    }
322
323    /// Current generation number.
324    pub fn generation(&self) -> u32 {
325        self.generation
326    }
327
328    /// Get fitness statistics for the current population.
329    pub fn stats(&self) -> PopulationStats {
330        let fitnesses: Vec<f32> = self.population.iter().map(|k| k.fitness()).collect();
331        let mean = fitnesses.iter().sum::<f32>() / fitnesses.len().max(1) as f32;
332        let max = fitnesses
333            .iter()
334            .cloned()
335            .fold(f32::NEG_INFINITY, f32::max);
336        let min = fitnesses.iter().cloned().fold(f32::INFINITY, f32::min);
337        let variance = fitnesses.iter().map(|f| (f - mean).powi(2)).sum::<f32>()
338            / fitnesses.len().max(1) as f32;
339
340        PopulationStats {
341            generation: self.generation,
342            pop_size: self.population.len(),
343            mean_fitness: mean,
344            max_fitness: max,
345            min_fitness: min,
346            fitness_variance: variance,
347            best_ever_fitness: self.best_kernel.as_ref().map(|k| k.fitness()).unwrap_or(0.0),
348        }
349    }
350}
351
352/// Statistics about the current population.
353#[derive(Debug, Clone, Serialize, Deserialize)]
354pub struct PopulationStats {
355    pub generation: u32,
356    pub pop_size: usize,
357    pub mean_fitness: f32,
358    pub max_fitness: f32,
359    pub min_fitness: f32,
360    pub fitness_variance: f32,
361    pub best_ever_fitness: f32,
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_policy_knobs_default() {
        let knobs = PolicyKnobs::default_knobs();
        assert!(!knobs.skip_mode);
        assert!(knobs.prepass_enabled);
        assert!(knobs.speculation_threshold > 0.0);
    }

    #[test]
    fn test_policy_knobs_mutate() {
        let knobs = PolicyKnobs::default_knobs();
        let mut rng = rand::thread_rng();
        // `gen::<f32>()` lies in [0, 1), so a rate of 1.0 mutates every knob:
        // the booleans must flip and every value must land in its mutation range.
        let mutated = knobs.mutate(&mut rng, 1.0);
        assert_ne!(mutated.skip_mode, knobs.skip_mode);
        assert_ne!(mutated.prepass_enabled, knobs.prepass_enabled);
        assert!((0.01..=0.5).contains(&mutated.speculation_threshold));
        assert!((0.01..=0.5).contains(&mutated.exploration_budget));
        assert!(mutated.max_retries < 5);
        assert!((1..32).contains(&mutated.batch_size));
        assert!((0.5..=0.99).contains(&mutated.cost_decay));
        assert!((0.3..=0.95).contains(&mutated.confidence_floor));
    }

    #[test]
    fn test_policy_knobs_mutate_zero_rate_keeps_knobs() {
        let knobs = PolicyKnobs::default_knobs();
        let mut rng = rand::thread_rng();
        // No draw from [0, 1) is below 0.0, so nothing may change.
        let same = knobs.mutate(&mut rng, 0.0);
        assert_eq!(same.skip_mode, knobs.skip_mode);
        assert_eq!(same.prepass_enabled, knobs.prepass_enabled);
        assert_eq!(same.speculation_threshold, knobs.speculation_threshold);
        assert_eq!(same.exploration_budget, knobs.exploration_budget);
        assert_eq!(same.max_retries, knobs.max_retries);
        assert_eq!(same.batch_size, knobs.batch_size);
        assert_eq!(same.cost_decay, knobs.cost_decay);
        assert_eq!(same.confidence_floor, knobs.confidence_floor);
    }

    #[test]
    fn test_policy_kernel_fitness() {
        let mut kernel = PolicyKernel::new("test".into());
        assert_eq!(kernel.fitness(), 0.0);

        kernel.record_score(DomainId("d1".into()), 0.8, 1.0);
        kernel.record_score(DomainId("d2".into()), 0.6, 1.0);
        assert!((kernel.fitness() - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_policy_kernel_mutate_lineage() {
        let parent = PolicyKernel::new("parent".into());
        let mut rng = rand::thread_rng();
        let child = parent.mutate("child".into(), &mut rng);
        assert_eq!(child.id, "child");
        assert_eq!(child.parent_id.as_deref(), Some("parent"));
        assert_eq!(child.generation, parent.generation + 1);
        assert!(child.holdout_scores.is_empty());
        assert_eq!(child.cycles, 0);
        assert!(!child.replay_verified);
    }

    #[test]
    fn test_cost_adjusted_fitness() {
        let mut kernel = PolicyKernel::new("cost".into());
        assert_eq!(kernel.cost_adjusted_fitness(), 0.0);

        // Mean cost 0.5 -> penalty 0.5 * 30% = 15%.
        kernel.record_score(DomainId("d1".into()), 0.5, 0.5);
        assert!((kernel.cost_adjusted_fitness() - 0.425).abs() < 1e-6);

        // Mean cost above 1.0 saturates at the full 30% penalty.
        let mut pricey = PolicyKernel::new("pricey".into());
        pricey.record_score(DomainId("d1".into()), 0.5, 5.0);
        assert!((pricey.cost_adjusted_fitness() - 0.35).abs() < 1e-6);
    }

    #[test]
    fn test_population_search_evolve() {
        let mut search = PopulationSearch::new(8);
        assert_eq!(search.population().len(), 8);

        // Simulate evaluation
        for i in 0..8 {
            if let Some(kernel) = search.kernel_mut(i) {
                let score = 0.3 + (i as f32) * 0.08;
                kernel.record_score(DomainId("test".into()), score, 1.0);
            }
        }

        search.evolve();
        assert_eq!(search.population().len(), 8);
        assert_eq!(search.generation(), 1);

        // All costs are equal, so the top-scoring kernel (i = 7) is the best.
        let best = search.best().expect("best kernel recorded after evolve");
        assert!((best.fitness() - (0.3 + 7.0 * 0.08)).abs() < 1e-6);

        // The new generation is renamed sequentially and starts unevaluated.
        for (i, kernel) in search.population().iter().enumerate() {
            assert_eq!(kernel.id, format!("kernel_g1_{}", i));
            assert!(kernel.holdout_scores.is_empty());
            assert_eq!(kernel.cycles, 0);
        }
    }

    #[test]
    fn test_population_stats() {
        let mut search = PopulationSearch::new(4);

        for i in 0..4 {
            if let Some(kernel) = search.kernel_mut(i) {
                kernel.record_score(DomainId("test".into()), (i as f32) * 0.25, 1.0);
            }
        }

        let stats = search.stats();
        assert_eq!(stats.pop_size, 4);
        assert_eq!(stats.generation, 0);
        assert!(stats.max_fitness >= stats.min_fitness);
        assert!(stats.mean_fitness >= stats.min_fitness);
        assert!(stats.mean_fitness <= stats.max_fitness);
        // Fitnesses are exactly 0.0, 0.25, 0.5 and 0.75.
        assert_eq!(stats.min_fitness, 0.0);
        assert_eq!(stats.max_fitness, 0.75);
        assert!((stats.mean_fitness - 0.375).abs() < 1e-6);
        // Nothing has been recorded as "best" before the first evolve.
        assert_eq!(stats.best_ever_fitness, 0.0);
    }

    #[test]
    fn test_crossover() {
        let a = PolicyKnobs {
            skip_mode: true,
            prepass_enabled: false,
            speculation_threshold: 0.1,
            exploration_budget: 0.1,
            max_retries: 1,
            batch_size: 4,
            cost_decay: 0.8,
            confidence_floor: 0.5,
        };
        let b = PolicyKnobs {
            skip_mode: false,
            prepass_enabled: true,
            speculation_threshold: 0.4,
            exploration_budget: 0.4,
            max_retries: 4,
            batch_size: 16,
            cost_decay: 0.95,
            confidence_floor: 0.9,
        };

        let mut rng = rand::thread_rng();
        let child = a.crossover(&b, &mut rng);

        // Every non-boolean knob must be copied verbatim from one parent.
        assert!(child.speculation_threshold == 0.1 || child.speculation_threshold == 0.4);
        assert!(child.exploration_budget == 0.1 || child.exploration_budget == 0.4);
        assert!(child.max_retries == 1 || child.max_retries == 4);
        assert!(child.batch_size == 4 || child.batch_size == 16);
        assert!(child.cost_decay == 0.8 || child.cost_decay == 0.95);
        assert!(child.confidence_floor == 0.5 || child.confidence_floor == 0.9);
    }
}