ruvector_domain_expansion/
policy_kernel.rs1use crate::domain::DomainId;
14use rand::Rng;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct PolicyKnobs {
21 pub skip_mode: bool,
23 pub prepass_enabled: bool,
25 pub speculation_threshold: f32,
27 pub exploration_budget: f32,
29 pub max_retries: u32,
31 pub batch_size: usize,
33 pub cost_decay: f32,
35 pub confidence_floor: f32,
37}
38
39impl PolicyKnobs {
40 pub fn default_knobs() -> Self {
42 Self {
43 skip_mode: false,
44 prepass_enabled: true,
45 speculation_threshold: 0.15,
46 exploration_budget: 0.2,
47 max_retries: 2,
48 batch_size: 8,
49 cost_decay: 0.9,
50 confidence_floor: 0.7,
51 }
52 }
53
54 pub fn mutate(&self, rng: &mut impl Rng, mutation_rate: f32) -> Self {
56 let mut knobs = self.clone();
57
58 if rng.gen::<f32>() < mutation_rate {
59 knobs.skip_mode = !knobs.skip_mode;
60 }
61 if rng.gen::<f32>() < mutation_rate {
62 knobs.prepass_enabled = !knobs.prepass_enabled;
63 }
64 if rng.gen::<f32>() < mutation_rate {
65 let delta: f32 = rng.gen_range(-0.1..0.1);
66 knobs.speculation_threshold = (knobs.speculation_threshold + delta).clamp(0.01, 0.5);
67 }
68 if rng.gen::<f32>() < mutation_rate {
69 let delta: f32 = rng.gen_range(-0.1..0.1);
70 knobs.exploration_budget = (knobs.exploration_budget + delta).clamp(0.01, 0.5);
71 }
72 if rng.gen::<f32>() < mutation_rate {
73 knobs.max_retries = rng.gen_range(0..5);
74 }
75 if rng.gen::<f32>() < mutation_rate {
76 knobs.batch_size = rng.gen_range(1..32);
77 }
78 if rng.gen::<f32>() < mutation_rate {
79 let delta: f32 = rng.gen_range(-0.05..0.05);
80 knobs.cost_decay = (knobs.cost_decay + delta).clamp(0.5, 0.99);
81 }
82 if rng.gen::<f32>() < mutation_rate {
83 let delta: f32 = rng.gen_range(-0.1..0.1);
84 knobs.confidence_floor = (knobs.confidence_floor + delta).clamp(0.3, 0.95);
85 }
86
87 knobs
88 }
89
90 pub fn crossover(&self, other: &PolicyKnobs, rng: &mut impl Rng) -> Self {
92 Self {
93 skip_mode: if rng.gen() { self.skip_mode } else { other.skip_mode },
94 prepass_enabled: if rng.gen() {
95 self.prepass_enabled
96 } else {
97 other.prepass_enabled
98 },
99 speculation_threshold: if rng.gen() {
100 self.speculation_threshold
101 } else {
102 other.speculation_threshold
103 },
104 exploration_budget: if rng.gen() {
105 self.exploration_budget
106 } else {
107 other.exploration_budget
108 },
109 max_retries: if rng.gen() {
110 self.max_retries
111 } else {
112 other.max_retries
113 },
114 batch_size: if rng.gen() {
115 self.batch_size
116 } else {
117 other.batch_size
118 },
119 cost_decay: if rng.gen() {
120 self.cost_decay
121 } else {
122 other.cost_decay
123 },
124 confidence_floor: if rng.gen() {
125 self.confidence_floor
126 } else {
127 other.confidence_floor
128 },
129 }
130 }
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct PolicyKernel {
136 pub id: String,
138 pub knobs: PolicyKnobs,
140 pub holdout_scores: HashMap<DomainId, f32>,
142 pub total_cost: f32,
144 pub cycles: u64,
146 pub generation: u32,
148 pub parent_id: Option<String>,
150 pub replay_verified: bool,
152}
153
154impl PolicyKernel {
155 pub fn new(id: String) -> Self {
157 Self {
158 id,
159 knobs: PolicyKnobs::default_knobs(),
160 holdout_scores: HashMap::new(),
161 total_cost: 0.0,
162 cycles: 0,
163 generation: 0,
164 parent_id: None,
165 replay_verified: false,
166 }
167 }
168
169 pub fn mutate(&self, child_id: String, rng: &mut impl Rng) -> Self {
171 Self {
172 id: child_id,
173 knobs: self.knobs.mutate(rng, 0.3),
174 holdout_scores: HashMap::new(),
175 total_cost: 0.0,
176 cycles: 0,
177 generation: self.generation + 1,
178 parent_id: Some(self.id.clone()),
179 replay_verified: false,
180 }
181 }
182
183 pub fn record_score(&mut self, domain_id: DomainId, score: f32, cost: f32) {
185 self.holdout_scores.insert(domain_id, score);
186 self.total_cost += cost;
187 self.cycles += 1;
188 }
189
190 pub fn fitness(&self) -> f32 {
192 if self.holdout_scores.is_empty() {
193 return 0.0;
194 }
195 let total: f32 = self.holdout_scores.values().sum();
196 total / self.holdout_scores.len() as f32
197 }
198
199 pub fn cost_adjusted_fitness(&self) -> f32 {
201 let raw = self.fitness();
202 let cost_penalty = (self.total_cost / self.cycles.max(1) as f32).min(1.0);
203 raw * (1.0 - cost_penalty * 0.3) }
205}
206
207#[derive(Clone)]
209pub struct PopulationSearch {
210 population: Vec<PolicyKernel>,
212 pop_size: usize,
214 best_kernel: Option<PolicyKernel>,
216 generation: u32,
218}
219
220impl PopulationSearch {
221 pub fn new(pop_size: usize) -> Self {
223 let mut rng = rand::thread_rng();
224 let population: Vec<PolicyKernel> = (0..pop_size)
225 .map(|i| {
226 let mut kernel = PolicyKernel::new(format!("kernel_g0_{}", i));
227 kernel.knobs = PolicyKnobs::default_knobs().mutate(&mut rng, 0.8);
229 kernel
230 })
231 .collect();
232
233 Self {
234 population,
235 pop_size,
236 best_kernel: None,
237 generation: 0,
238 }
239 }
240
241 pub fn population(&self) -> &[PolicyKernel] {
243 &self.population
244 }
245
246 pub fn kernel_mut(&mut self, index: usize) -> Option<&mut PolicyKernel> {
248 self.population.get_mut(index)
249 }
250
251 pub fn evolve(&mut self) {
253 let mut rng = rand::thread_rng();
254 self.generation += 1;
255
256 self.population
258 .sort_by(|a, b| {
259 b.cost_adjusted_fitness()
260 .partial_cmp(&a.cost_adjusted_fitness())
261 .unwrap_or(std::cmp::Ordering::Equal)
262 });
263
264 if let Some(best) = self.population.first() {
266 if self
267 .best_kernel
268 .as_ref()
269 .map_or(true, |b| best.fitness() > b.fitness())
270 {
271 self.best_kernel = Some(best.clone());
272 }
273 }
274
275 let elite_count = (self.pop_size / 4).max(1);
277 let elites: Vec<PolicyKernel> = self.population[..elite_count].to_vec();
278
279 let mut next_gen = Vec::with_capacity(self.pop_size);
281
282 for elite in &elites {
284 let mut kept = elite.clone();
285 kept.id = format!("kernel_g{}_{}", self.generation, next_gen.len());
286 kept.holdout_scores.clear();
287 kept.total_cost = 0.0;
288 kept.cycles = 0;
289 next_gen.push(kept);
290 }
291
292 while next_gen.len() < self.pop_size {
294 let parent_idx = rng.gen_range(0..elites.len());
295 let child_id = format!("kernel_g{}_{}", self.generation, next_gen.len());
296
297 let child = if rng.gen::<f32>() < 0.3 && elites.len() > 1 {
298 let other_idx = (parent_idx + 1 + rng.gen_range(0..elites.len() - 1)) % elites.len();
300 let mut child = PolicyKernel::new(child_id);
301 child.knobs = elites[parent_idx]
302 .knobs
303 .crossover(&elites[other_idx].knobs, &mut rng);
304 child.generation = self.generation;
305 child.parent_id = Some(elites[parent_idx].id.clone());
306 child
307 } else {
308 elites[parent_idx].mutate(child_id, &mut rng)
310 };
311
312 next_gen.push(child);
313 }
314
315 self.population = next_gen;
316 }
317
318 pub fn best(&self) -> Option<&PolicyKernel> {
320 self.best_kernel.as_ref()
321 }
322
323 pub fn generation(&self) -> u32 {
325 self.generation
326 }
327
328 pub fn stats(&self) -> PopulationStats {
330 let fitnesses: Vec<f32> = self.population.iter().map(|k| k.fitness()).collect();
331 let mean = fitnesses.iter().sum::<f32>() / fitnesses.len().max(1) as f32;
332 let max = fitnesses
333 .iter()
334 .cloned()
335 .fold(f32::NEG_INFINITY, f32::max);
336 let min = fitnesses.iter().cloned().fold(f32::INFINITY, f32::min);
337 let variance = fitnesses.iter().map(|f| (f - mean).powi(2)).sum::<f32>()
338 / fitnesses.len().max(1) as f32;
339
340 PopulationStats {
341 generation: self.generation,
342 pop_size: self.population.len(),
343 mean_fitness: mean,
344 max_fitness: max,
345 min_fitness: min,
346 fitness_variance: variance,
347 best_ever_fitness: self.best_kernel.as_ref().map(|k| k.fitness()).unwrap_or(0.0),
348 }
349 }
350}
351
352#[derive(Debug, Clone, Serialize, Deserialize)]
354pub struct PopulationStats {
355 pub generation: u32,
356 pub pop_size: usize,
357 pub mean_fitness: f32,
358 pub max_fitness: f32,
359 pub min_fitness: f32,
360 pub fitness_variance: f32,
361 pub best_ever_fitness: f32,
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline knobs have the documented switch values.
    #[test]
    fn test_policy_knobs_default() {
        let knobs = PolicyKnobs::default_knobs();
        assert!(!knobs.skip_mode);
        assert!(knobs.prepass_enabled);
        assert!(knobs.speculation_threshold > 0.0);
    }

    /// With a mutation rate of 1.0 every knob is mutated (the trial draw is
    /// always below 1.0), so both switches flip and every numeric knob must
    /// land inside its range.
    #[test]
    fn test_policy_knobs_mutate() {
        let knobs = PolicyKnobs::default_knobs();
        let mut rng = rand::thread_rng();
        let mutated = knobs.mutate(&mut rng, 1.0);

        assert!(mutated.skip_mode);
        assert!(!mutated.prepass_enabled);
        assert!((0.01..=0.5).contains(&mutated.speculation_threshold));
        assert!((0.01..=0.5).contains(&mutated.exploration_budget));
        assert!(mutated.max_retries < 5);
        assert!((1..32).contains(&mutated.batch_size));
        assert!((0.5..=0.99).contains(&mutated.cost_decay));
        assert!((0.3..=0.95).contains(&mutated.confidence_floor));
    }

    /// Fitness is the mean of recorded scores; cost-adjusted fitness keeps
    /// 70% of it once the average cost per cycle reaches 1.0.
    #[test]
    fn test_policy_kernel_fitness() {
        let mut kernel = PolicyKernel::new("test".into());
        assert_eq!(kernel.fitness(), 0.0);

        kernel.record_score(DomainId("d1".into()), 0.8, 1.0);
        kernel.record_score(DomainId("d2".into()), 0.6, 1.0);
        assert!((kernel.fitness() - 0.7).abs() < 1e-6);
        assert_eq!(kernel.cycles, 2);
        assert!((kernel.cost_adjusted_fitness() - 0.49).abs() < 1e-5);
    }

    /// One evolve step keeps the population size, remembers the top scorer
    /// and hands back a generation whose kernels all await fresh scores.
    #[test]
    fn test_population_search_evolve() {
        let mut search = PopulationSearch::new(8);
        assert_eq!(search.population().len(), 8);

        for i in 0..8 {
            if let Some(kernel) = search.kernel_mut(i) {
                let score = 0.3 + (i as f32) * 0.08;
                kernel.record_score(DomainId("test".into()), score, 1.0);
            }
        }

        search.evolve();
        assert_eq!(search.population().len(), 8);
        assert_eq!(search.generation(), 1);

        // All kernels paid the same cost, so the highest raw score wins.
        let best = search.best().expect("evolve should remember a best kernel");
        let top_score = 0.3 + 7.0_f32 * 0.08;
        assert!((best.fitness() - top_score).abs() < 1e-5);

        assert!(search
            .population()
            .iter()
            .all(|k| k.holdout_scores.is_empty()));
    }

    /// Stats over scores 0.0, 0.25, 0.5, 0.75.
    #[test]
    fn test_population_stats() {
        let mut search = PopulationSearch::new(4);

        for i in 0..4 {
            if let Some(kernel) = search.kernel_mut(i) {
                kernel.record_score(DomainId("test".into()), (i as f32) * 0.25, 1.0);
            }
        }

        let stats = search.stats();
        assert_eq!(stats.pop_size, 4);
        assert_eq!(stats.generation, 0);
        assert!(stats.max_fitness >= stats.min_fitness);
        assert!(stats.mean_fitness >= stats.min_fitness);
        assert!(stats.mean_fitness <= stats.max_fitness);
        assert!((stats.min_fitness - 0.0).abs() < 1e-6);
        assert!((stats.max_fitness - 0.75).abs() < 1e-6);
        assert!((stats.mean_fitness - 0.375).abs() < 1e-6);
        // No evolve() yet, so no best kernel has been remembered.
        assert_eq!(stats.best_ever_fitness, 0.0);
    }

    /// Every knob of a crossover child is copied verbatim from one parent.
    #[test]
    fn test_crossover() {
        let a = PolicyKnobs {
            skip_mode: true,
            prepass_enabled: false,
            speculation_threshold: 0.1,
            exploration_budget: 0.1,
            max_retries: 1,
            batch_size: 4,
            cost_decay: 0.8,
            confidence_floor: 0.5,
        };
        let b = PolicyKnobs {
            skip_mode: false,
            prepass_enabled: true,
            speculation_threshold: 0.4,
            exploration_budget: 0.4,
            max_retries: 4,
            batch_size: 16,
            cost_decay: 0.95,
            confidence_floor: 0.9,
        };

        let mut rng = rand::thread_rng();
        let child = a.crossover(&b, &mut rng);

        // Values are copied, never blended, so exact float equality holds.
        assert!(
            child.speculation_threshold == a.speculation_threshold
                || child.speculation_threshold == b.speculation_threshold
        );
        assert!(
            child.exploration_budget == a.exploration_budget
                || child.exploration_budget == b.exploration_budget
        );
        assert!(child.max_retries == 1 || child.max_retries == 4);
        assert!(child.batch_size == 4 || child.batch_size == 16);
        assert!(child.cost_decay == a.cost_decay || child.cost_decay == b.cost_decay);
        assert!(
            child.confidence_floor == a.confidence_floor
                || child.confidence_floor == b.confidence_floor
        );
    }
}