1#![warn(missing_docs)]
45
46pub mod cost_curve;
47pub mod domain;
48pub mod meta_learning;
49pub mod planning;
50pub mod policy_kernel;
51pub mod rust_synthesis;
52pub mod tool_orchestration;
53pub mod transfer;
54
55#[cfg(feature = "rvf")]
59pub mod rvf_bridge;
60
61pub use cost_curve::{
63 AccelerationEntry, AccelerationScoreboard, ConvergenceThresholds, CostCurve, CostCurvePoint,
64 ScoreboardSummary,
65};
66pub use domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task};
67pub use planning::PlanningDomain;
68pub use policy_kernel::{PolicyKernel, PolicyKnobs, PopulationSearch, PopulationStats};
69pub use rust_synthesis::RustSynthesisDomain;
70pub use tool_orchestration::ToolOrchestrationDomain;
71pub use meta_learning::{
72 CuriosityBonus, DecayingBeta, MetaLearningEngine, MetaLearningHealth, ParetoFront,
73 ParetoPoint, PlateauAction, PlateauDetector, RegretSummary, RegretTracker,
74};
75pub use transfer::{
76 ArmId, BetaParams, ContextBucket, DualPathResult, MetaThompsonEngine, TransferPrior,
77 TransferVerification,
78};
79
80use std::collections::HashMap;
81
82pub struct DomainExpansionEngine {
91 domains: HashMap<DomainId, Box<dyn Domain>>,
93 pub thompson: MetaThompsonEngine,
95 pub population: PopulationSearch,
97 pub scoreboard: AccelerationScoreboard,
99 pub meta: MetaLearningEngine,
101 holdouts: HashMap<DomainId, Vec<Task>>,
103 counterexamples: HashMap<DomainId, Vec<(Task, Solution, Evaluation)>>,
105}
106
107impl DomainExpansionEngine {
108 pub fn new() -> Self {
112 let arms = vec![
113 "greedy".into(),
114 "exploratory".into(),
115 "conservative".into(),
116 "speculative".into(),
117 ];
118
119 let mut engine = Self {
120 domains: HashMap::new(),
121 thompson: MetaThompsonEngine::new(arms),
122 population: PopulationSearch::new(8),
123 scoreboard: AccelerationScoreboard::new(),
124 meta: MetaLearningEngine::new(),
125 holdouts: HashMap::new(),
126 counterexamples: HashMap::new(),
127 };
128
129 engine.register_domain(Box::new(RustSynthesisDomain::new()));
131 engine.register_domain(Box::new(PlanningDomain::new()));
132 engine.register_domain(Box::new(ToolOrchestrationDomain::new()));
133
134 engine
135 }
136
137 pub fn register_domain(&mut self, domain: Box<dyn Domain>) {
139 let id = domain.id().clone();
140 self.thompson.init_domain_uniform(id.clone());
141 self.domains.insert(id, domain);
142 }
143
144 pub fn generate_holdouts(&mut self, tasks_per_domain: usize, difficulty: f32) {
146 for (id, domain) in &self.domains {
147 let tasks = domain.generate_tasks(tasks_per_domain, difficulty);
148 self.holdouts.insert(id.clone(), tasks);
149 }
150 }
151
152 pub fn generate_tasks(
154 &self,
155 domain_id: &DomainId,
156 count: usize,
157 difficulty: f32,
158 ) -> Vec<Task> {
159 self.domains
160 .get(domain_id)
161 .map(|d| d.generate_tasks(count, difficulty))
162 .unwrap_or_default()
163 }
164
165 pub fn evaluate_and_record(
167 &mut self,
168 domain_id: &DomainId,
169 task: &Task,
170 solution: &Solution,
171 bucket: ContextBucket,
172 arm: ArmId,
173 ) -> Evaluation {
174 let eval = self
175 .domains
176 .get(domain_id)
177 .map(|d| d.evaluate(task, solution))
178 .unwrap_or_else(|| Evaluation::zero(vec!["Domain not found".into()]));
179
180 self.thompson.record_outcome(
182 domain_id,
183 bucket.clone(),
184 arm.clone(),
185 eval.score,
186 1.0, );
188
189 self.meta.record_decision(&bucket, &arm, eval.score);
191
192 if eval.score < 0.3 {
194 self.counterexamples
195 .entry(domain_id.clone())
196 .or_default()
197 .push((task.clone(), solution.clone(), eval.clone()));
198 }
199
200 eval
201 }
202
203 pub fn embed(&self, domain_id: &DomainId, solution: &Solution) -> Option<DomainEmbedding> {
205 self.domains.get(domain_id).map(|d| d.embed(solution))
206 }
207
208 pub fn initiate_transfer(&mut self, source: &DomainId, target: &DomainId) {
211 if let Some(prior) = self.thompson.extract_prior(source) {
212 self.thompson
213 .init_domain_with_transfer(target.clone(), &prior);
214 }
215 }
216
217 pub fn verify_transfer(
219 &self,
220 source: &DomainId,
221 target: &DomainId,
222 source_before: f32,
223 source_after: f32,
224 target_before: f32,
225 target_after: f32,
226 baseline_cycles: u64,
227 transfer_cycles: u64,
228 ) -> TransferVerification {
229 TransferVerification::verify(
230 source.clone(),
231 target.clone(),
232 source_before,
233 source_after,
234 target_before,
235 target_after,
236 baseline_cycles,
237 transfer_cycles,
238 )
239 }
240
241 pub fn evaluate_population(&mut self) {
243 let holdout_snapshot: HashMap<DomainId, Vec<Task>> = self.holdouts.clone();
244 let domain_ids: Vec<DomainId> = self.domains.keys().cloned().collect();
245
246 for i in 0..self.population.population().len() {
247 for domain_id in &domain_ids {
248 if let Some(holdout_tasks) = holdout_snapshot.get(domain_id) {
249 let mut total_score = 0.0f32;
250 let mut count = 0;
251
252 for task in holdout_tasks {
253 if let Some(domain) = self.domains.get(domain_id) {
254 if let Some(ref_sol) = domain.reference_solution(task) {
255 let eval = domain.evaluate(task, &ref_sol);
256 total_score += eval.score;
257 count += 1;
258 }
259 }
260 }
261
262 let avg_score = if count > 0 {
263 total_score / count as f32
264 } else {
265 0.0
266 };
267
268 if let Some(kernel) = self.population.kernel_mut(i) {
269 kernel.record_score(domain_id.clone(), avg_score, 1.0);
270 }
271 }
272 }
273 }
274 }
275
276 pub fn evolve_population(&mut self) {
278 let gen = self.population.generation();
280 for kernel in self.population.population() {
281 let accuracy = kernel.fitness();
282 let cost = if kernel.cycles > 0 {
283 kernel.total_cost / kernel.cycles as f32
284 } else {
285 0.0
286 };
287 let robustness = if kernel.holdout_scores.len() > 1 {
289 let mean = accuracy;
290 let var: f32 = kernel
291 .holdout_scores
292 .values()
293 .map(|s| (s - mean).powi(2))
294 .sum::<f32>()
295 / kernel.holdout_scores.len() as f32;
296 (1.0 - var.sqrt()).max(0.0)
297 } else {
298 accuracy
299 };
300 self.meta.record_kernel(&kernel.id, accuracy, cost, robustness, gen);
301 }
302
303 self.population.evolve();
304 }
305
306 pub fn best_kernel(&self) -> Option<&PolicyKernel> {
308 self.population.best()
309 }
310
311 pub fn population_stats(&self) -> PopulationStats {
313 self.population.stats()
314 }
315
316 pub fn scoreboard_summary(&self) -> ScoreboardSummary {
318 self.scoreboard.summary()
319 }
320
321 pub fn domain_ids(&self) -> Vec<DomainId> {
323 self.domains.keys().cloned().collect()
324 }
325
326 pub fn counterexamples(
328 &self,
329 domain_id: &DomainId,
330 ) -> &[(Task, Solution, Evaluation)] {
331 self.counterexamples
332 .get(domain_id)
333 .map(|v| v.as_slice())
334 .unwrap_or(&[])
335 }
336
337 pub fn select_arm(
339 &self,
340 domain_id: &DomainId,
341 bucket: &ContextBucket,
342 ) -> Option<ArmId> {
343 let mut rng = rand::thread_rng();
344 self.thompson.select_arm(domain_id, bucket, &mut rng)
345 }
346
347 pub fn should_speculate(
349 &self,
350 domain_id: &DomainId,
351 bucket: &ContextBucket,
352 ) -> bool {
353 self.thompson.is_uncertain(domain_id, bucket, 0.15)
354 }
355
356 pub fn select_arm_curious(
361 &self,
362 domain_id: &DomainId,
363 bucket: &ContextBucket,
364 ) -> Option<ArmId> {
365 let mut rng = rand::thread_rng();
366 let prior = self.thompson.extract_prior(domain_id)?;
368 let arms: Vec<ArmId> = prior
369 .bucket_priors
370 .get(bucket)
371 .map(|m| m.keys().cloned().collect())
372 .unwrap_or_default();
373
374 if arms.is_empty() {
375 return self.thompson.select_arm(domain_id, bucket, &mut rng);
376 }
377
378 let mut best_arm = None;
379 let mut best_score = f32::NEG_INFINITY;
380
381 for arm in &arms {
382 let params = prior.get_prior(bucket, arm);
383 let sample = params.sample(&mut rng);
384 let boosted = self.meta.boosted_score(bucket, arm, sample);
385
386 if boosted > best_score {
387 best_score = boosted;
388 best_arm = Some(arm.clone());
389 }
390 }
391
392 best_arm.or_else(|| self.thompson.select_arm(domain_id, bucket, &mut rng))
393 }
394
395 pub fn meta_health(&self) -> MetaLearningHealth {
397 self.meta.health_check()
398 }
399
400 pub fn check_plateau(
402 &mut self,
403 domain_id: &DomainId,
404 ) -> PlateauAction {
405 if let Some(curve) = self.scoreboard.curves.get(domain_id) {
406 self.meta.check_plateau(&curve.points)
407 } else {
408 PlateauAction::Continue
409 }
410 }
411
412 pub fn regret_summary(&self) -> RegretSummary {
414 self.meta.regret.summary()
415 }
416
417 pub fn pareto_front(&self) -> &ParetoFront {
419 &self.meta.pareto
420 }
421}
422
423impl Default for DomainExpansionEngine {
424 fn default() -> Self {
425 Self::new()
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`ContextBucket`] from a difficulty tier and a category.
    fn bucket(difficulty_tier: &'static str, category: &'static str) -> ContextBucket {
        ContextBucket {
            difficulty_tier: difficulty_tier.into(),
            category: category.into(),
        }
    }

    /// A new engine starts with three registered domains.
    #[test]
    fn test_engine_creation() {
        let registered = DomainExpansionEngine::new().domain_ids();
        assert_eq!(registered.len(), 3);
    }

    /// Every registered domain produces the requested number of tasks.
    #[test]
    fn test_generate_tasks_all_domains() {
        let engine = DomainExpansionEngine::new();
        for domain_id in engine.domain_ids() {
            let generated = engine.generate_tasks(&domain_id, 5, 0.5);
            assert_eq!(generated.len(), 5);
        }
    }

    /// Arm selection yields an arm for every registered domain.
    #[test]
    fn test_arm_selection() {
        let engine = DomainExpansionEngine::new();
        let medium_general = bucket("medium", "general");
        for domain_id in engine.domain_ids() {
            let chosen = engine.select_arm(&domain_id, &medium_general);
            assert!(chosen.is_some());
        }
    }

    /// Evaluating a solution returns a score inside `[0, 1]`.
    #[test]
    fn test_evaluate_and_record() {
        let mut engine = DomainExpansionEngine::new();
        let domain_id = DomainId("rust_synthesis".into());
        let tasks = engine.generate_tasks(&domain_id, 1, 0.3);
        let task = &tasks[0];

        let solution = Solution {
            task_id: task.id.clone(),
            content: "fn double(values: &[i64]) -> Vec<i64> { values.iter().map(|&x| x * 2).collect() }".into(),
            data: serde_json::Value::Null,
        };

        let eval = engine.evaluate_and_record(
            &domain_id,
            task,
            &solution,
            bucket("easy", "transform"),
            ArmId("greedy".into()),
        );
        assert!((0.0..=1.0).contains(&eval.score));
    }

    /// Embeddings from different domains share a dimension and have a cosine
    /// similarity inside `[-1, 1]`.
    #[test]
    fn test_cross_domain_embedding() {
        let engine = DomainExpansionEngine::new();

        let rust_sol = Solution {
            task_id: "rust".into(),
            content: "fn foo() { for i in 0..10 { if i > 5 { } } }".into(),
            data: serde_json::Value::Null,
        };
        let plan_sol = Solution {
            task_id: "plan".into(),
            content: "allocate cpu then schedule parallel jobs".into(),
            data: serde_json::json!({"steps": []}),
        };

        let rust_domain = DomainId("rust_synthesis".into());
        let plan_domain = DomainId("structured_planning".into());

        let rust_emb = engine.embed(&rust_domain, &rust_sol).unwrap();
        let plan_emb = engine.embed(&plan_domain, &plan_sol).unwrap();

        assert_eq!(rust_emb.dim, plan_emb.dim);

        let sim = rust_emb.cosine_similarity(&plan_emb);
        assert!((-1.0..=1.0).contains(&sim));
    }

    /// Recording outcomes on a source domain, transferring to a target and
    /// verifying the transfer yields a promotable, accelerating result.
    #[test]
    fn test_transfer_flow() {
        let mut engine = DomainExpansionEngine::new();
        let source = DomainId("rust_synthesis".into());
        let target = DomainId("structured_planning".into());

        let medium_algorithm = bucket("medium", "algorithm");

        for _ in 0..30 {
            engine.thompson.record_outcome(
                &source,
                medium_algorithm.clone(),
                ArmId("greedy".into()),
                0.85,
                1.0,
            );
        }

        engine.initiate_transfer(&source, &target);

        let verification = engine.verify_transfer(
            &source,
            &target,
            0.85,  // source_before
            0.845, // source_after
            0.3,   // target_before
            0.7,   // target_after
            100,   // baseline_cycles
            45,    // transfer_cycles
        );

        assert!(verification.promotable);
        assert!(verification.acceleration_factor > 1.0);
    }

    /// Evolving the population advances the generation counter from 0 to 1.
    #[test]
    fn test_population_evolution() {
        let mut engine = DomainExpansionEngine::new();
        engine.generate_holdouts(3, 0.3);
        engine.evaluate_population();

        let stats_before = engine.population_stats();
        assert_eq!(stats_before.generation, 0);

        engine.evolve_population();

        let stats_after = engine.population_stats();
        assert_eq!(stats_after.generation, 1);
    }

    /// A bucket with no recorded outcomes triggers speculation.
    #[test]
    fn test_speculation_trigger() {
        let engine = DomainExpansionEngine::new();
        let hard_unknown = bucket("hard", "unknown");
        let domain_id = DomainId("rust_synthesis".into());

        assert!(engine.should_speculate(&domain_id, &hard_unknown));
    }

    /// An empty solution scores below 0.3 and is stored as a counterexample.
    #[test]
    fn test_counterexample_tracking() {
        let mut engine = DomainExpansionEngine::new();
        let domain_id = DomainId("rust_synthesis".into());
        let tasks = engine.generate_tasks(&domain_id, 1, 0.9);
        let task = &tasks[0];

        let solution = Solution {
            task_id: task.id.clone(),
            content: "".into(),
            data: serde_json::Value::Null,
        };

        let eval = engine.evaluate_and_record(
            &domain_id,
            task,
            &solution,
            bucket("hard", "algorithm"),
            ArmId("speculative".into()),
        );
        assert!(eval.score < 0.3);

        assert!(!engine.counterexamples(&domain_id).is_empty());
    }
}