1use crate::ast_mcts::{AstMctsOracle, MctsResult, RuleId, mcts_search};
29use crate::evolution::{Chromosome, GenePool, population::random_chromosome};
30use crate::lineage::Lineage;
31use crate::search::{EvalCandidate, SearchAlgorithm, fitness_cmp};
32use crate::types::{Budget, EvolutionError, OracleVerdict, SearchStats};
33use rand::RngCore;
34use rand::rngs::StdRng;
35use serde::{Deserialize, Serialize};
36use std::collections::HashMap;
37
/// Number of MCTS iterations each refill runs when no explicit budget is given
/// (used by `AstMctsAlgorithm::new`).
pub const DEFAULT_MCTS_BUDGET: u64 = 64;

/// Default UCB1 exploration constant, `sqrt(2)`.
pub const DEFAULT_UCB1_C: f64 = core::f64::consts::SQRT_2;
43
/// Oracle handed to `mcts_search` during a refill.
///
/// Every payload the search asks about is appended to `candidates`, which
/// lets the caller turn them into queued evaluation candidates afterwards.
struct InlineOracle<'sink> {
    /// Pseudo-random state, stepped once per `eval` call.
    jitter: u64,
    /// Copy of the owning algorithm's `bypass_found` flag at construction time.
    prior_bypass: bool,
    /// Output buffer receiving each evaluated payload, in call order.
    candidates: &'sink mut Vec<String>,
}
59
impl<'a> AstMctsOracle for InlineOracle<'a> {
    /// Records `candidate` in the shared sink, steps the jitter state, and
    /// reports the candidate as accepted.
    ///
    /// Returns `true` for every input: both arms of the branch below are `true`.
    fn eval(&mut self, candidate: &str) -> bool {
        self.candidates.push(candidate.to_string());
        // One step of a 64-bit LCG (Knuth's MMIX multiplier/increment). Within
        // this impl the value is only stored, never used to influence the verdict.
        self.jitter = self
            .jitter
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        // NOTE(review): both branches return `true`, so `prior_bypass` has no
        // effect on the result (clippy `if_same_then_else`). Collapsing this to
        // a bare `true` would leave the `prior_bypass` field never read. Drop the
        // field together with its initializer in `replenish` if the uniform
        // accept is intended; otherwise give one branch its intended verdict.
        if self.prior_bypass {
            true
        } else {
            true
        }
    }
}
78
/// Search algorithm that refines a single best payload with AST-level MCTS.
///
/// Candidates are produced in refills: each refill runs `mcts_search` from
/// `best_payload` (or draws random chromosomes while no payload is known) and
/// queues the results for external evaluation through the `SearchAlgorithm`
/// request/submit cycle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AstMctsAlgorithm {
    /// Current best chromosome; returned by `best()` and `population_snapshot()`.
    best: Chromosome,
    /// Gene pool used to draw random chromosomes while `best_payload` is empty.
    gene_pool: GenePool,
    /// Generation counter, advanced once per `submit_evaluations` call
    /// (saturating) and stamped into candidate lineage.
    generation: u32,
    /// Source of candidate ids; incremented (saturating) for each queued candidate.
    eval_counter: u64,
    /// Payload that the next MCTS refill starts from; empty until one is known.
    best_payload: String,
    /// Set once any submitted verdict has `passed`; makes `should_terminate` true.
    bypass_found: bool,
    /// Per-rule `(total visits, sum of mean_reward * visits)` accumulated from
    /// MCTS arm statistics, keyed by the rule's `u8` id. `serde(default)` lets
    /// checkpoints without this field deserialize with an empty map.
    #[serde(default)]
    rule_stats: HashMap<u8, (u64, f64)>,
    /// Candidates handed out by `request_evaluations` and awaiting a verdict,
    /// keyed by id. Not serialized, so it is empty after `restore`.
    #[serde(skip)]
    in_flight: HashMap<u64, Chromosome>,
    /// Budget argument passed to `mcts_search` on every refill.
    mcts_budget: u64,
    /// UCB1 exploration constant passed to `mcts_search` on every refill.
    ucb1_c: f64,
    /// Queued `(id, chromosome)` candidates not yet handed out. Not serialized.
    #[serde(skip)]
    pending: Vec<(u64, Chromosome)>,
}
117
118impl AstMctsAlgorithm {
119 #[must_use]
121 pub fn new() -> Self {
122 Self::with_config(DEFAULT_MCTS_BUDGET, DEFAULT_UCB1_C)
123 }
124
125 #[must_use]
130 pub fn with_config(mcts_budget: u64, ucb1_c: f64) -> Self {
131 Self {
132 best: Chromosome::new(vec![("ast_mcts_payload".into(), String::new())]),
133 gene_pool: GenePool::default_wafrift(),
134 generation: 0,
135 eval_counter: 0,
136 best_payload: String::new(),
137 bypass_found: false,
138 rule_stats: HashMap::new(),
139 in_flight: HashMap::new(),
140 mcts_budget,
141 ucb1_c,
142 pending: Vec::new(),
143 }
144 }
145
146 fn payload_from_chromosome(c: &Chromosome) -> &str {
149 c.gene("ast_mcts_payload")
150 .or_else(|| c.gene("payload"))
151 .unwrap_or("")
152 }
153
154 fn replenish(&mut self, n: usize, rng: &mut StdRng) {
159 if self.best_payload.is_empty() {
160 for _ in 0..n {
162 self.eval_counter = self.eval_counter.saturating_add(1);
163 let mut c = random_chromosome(&self.gene_pool, rng);
164 c.genes.push(("ast_mcts_payload".into(), String::new()));
165 c.lineage = Lineage::genesis(self.generation);
166 self.pending.push((self.eval_counter, c));
167 }
168 return;
169 }
170
171 let jitter: u64 = rng.next_u64();
173 let mut generated: Vec<String> = Vec::new();
174 let mut inline = InlineOracle {
175 candidates: &mut generated,
176 prior_bypass: self.bypass_found,
177 jitter,
178 };
179
180 let result: Option<MctsResult> = mcts_search(
181 &self.best_payload,
182 self.mcts_budget,
183 self.ucb1_c,
184 &mut inline,
185 );
186
187 if let Some(ref r) = result {
189 for &(action, visits, mean_reward) in &r.arm_stats {
190 let entry = self.rule_stats.entry(action.rule.0).or_insert((0, 0.0));
191 entry.0 = entry.0.saturating_add(visits);
192 let addend = if mean_reward.is_finite() {
196 mean_reward * (visits as f64)
197 } else {
198 0.0
199 };
200 entry.1 = if entry.1.is_finite() {
201 entry.1 + addend
202 } else {
203 addend
207 };
208 }
209 }
210
211 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
213 if let Some(ref r) = result
215 && !r.best_payload.is_empty()
216 && seen.insert(r.best_payload.clone())
217 {
218 self.eval_counter = self.eval_counter.saturating_add(1);
219 let mut c = self.best.clone();
220 let payload = r.best_payload.clone();
221 set_gene(&mut c, "ast_mcts_payload", &payload);
222 c.lineage = Lineage::mutation(
223 &self.best,
224 vec![crate::lineage::MutationOp {
225 gene_name: "ast_mcts_payload".into(),
226 from: self.best_payload.clone(),
227 to: payload.clone(),
228 operator: "ast_mcts:best_payload".into(),
229 }],
230 self.generation,
231 );
232 self.pending.push((self.eval_counter, c));
233 }
234
235 for payload in generated {
237 if self.pending.len() >= n {
238 break;
239 }
240 if payload.is_empty() || !seen.insert(payload.clone()) {
241 continue;
242 }
243 self.eval_counter = self.eval_counter.saturating_add(1);
244 let mut c = self.best.clone();
245 set_gene(&mut c, "ast_mcts_payload", &payload);
246 c.lineage = Lineage::mutation(
247 &self.best,
248 vec![crate::lineage::MutationOp {
249 gene_name: "ast_mcts_payload".into(),
250 from: self.best_payload.clone(),
251 to: payload.clone(),
252 operator: "ast_mcts:inline_candidate".into(),
253 }],
254 self.generation,
255 );
256 self.pending.push((self.eval_counter, c));
257 }
258
259 if self.pending.is_empty() {
261 self.eval_counter = self.eval_counter.saturating_add(1);
262 let mut c = self.best.clone();
263 set_gene(&mut c, "ast_mcts_payload", &self.best_payload);
264 c.lineage = Lineage::genesis(self.generation);
265 self.pending.push((self.eval_counter, c));
266 }
267 }
268}
269
270impl Default for AstMctsAlgorithm {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
276fn set_gene(c: &mut Chromosome, name: &str, value: &str) {
278 if let Some(entry) = c.genes.iter_mut().find(|(k, _)| k == name) {
279 entry.1 = value.to_string();
280 } else {
281 c.genes.push((name.to_string(), value.to_string()));
282 }
283}
284
285impl SearchAlgorithm for AstMctsAlgorithm {
286 fn name(&self) -> &'static str {
287 "ast_mcts"
288 }
289
290 fn initialize(&mut self, population: Vec<Chromosome>, gene_pool: &GenePool, _rng: &mut StdRng) {
291 self.gene_pool = gene_pool.clone();
292 self.generation = 0;
293 self.eval_counter = 0;
294 self.bypass_found = false;
295 self.pending.clear();
296 self.in_flight.clear();
297
298 if let Some(seed) = population
300 .into_iter()
301 .max_by(|a, b| fitness_cmp(a.fitness, b.fitness))
302 {
303 let payload = Self::payload_from_chromosome(&seed).to_string();
304 self.best_payload = payload;
305 self.best = seed;
306 }
307 set_gene(&mut self.best, "ast_mcts_payload", &self.best_payload);
309 }
310
311 fn request_evaluations(&mut self, n: usize, rng: &mut StdRng) -> Vec<EvalCandidate> {
312 if n == 0 {
313 return Vec::new();
314 }
315 if self.pending.is_empty() {
317 self.replenish(n, rng);
318 }
319
320 let drain_count = n.min(self.pending.len());
322 let batch: Vec<(u64, Chromosome)> = self.pending.drain(..drain_count).collect();
323
324 let mut out = Vec::with_capacity(batch.len());
325 for (id, chromosome) in batch {
326 self.in_flight.insert(id, chromosome.clone());
327 out.push(EvalCandidate { id, chromosome });
328 }
329 out
330 }
331
332 fn submit_evaluations(&mut self, results: Vec<(u64, OracleVerdict)>) {
333 for (id, verdict) in results {
334 let Some(mut chromosome) = self.in_flight.remove(&id) else {
335 continue;
336 };
337 chromosome.record_verdict(&verdict);
338
339 if verdict.passed || chromosome.fitness > self.best.fitness {
341 if verdict.passed && !self.bypass_found {
342 self.bypass_found = true;
343 }
344 let new_payload = chromosome
345 .gene("ast_mcts_payload")
346 .unwrap_or("")
347 .to_string();
348 if !new_payload.is_empty() {
349 self.best_payload = new_payload;
350 }
351 self.best = chromosome;
352 }
353 }
354 self.generation = self.generation.saturating_add(1);
355 }
356
357 fn should_terminate(&self, stats: &SearchStats, budget: &Budget) -> bool {
358 self.bypass_found
359 || stats.evaluations >= budget.max_requests
360 || stats.generation >= budget.max_generations
361 || stats.stagnation_counter >= budget.stagnation_limit
362 }
363
364 fn best(&self) -> Option<&Chromosome> {
365 Some(&self.best)
366 }
367
368 fn checkpoint(&self) -> Result<Vec<u8>, EvolutionError> {
369 serde_json::to_vec(self).map_err(EvolutionError::SerializationFailed)
370 }
371
372 fn restore(&mut self, bytes: &[u8]) -> Result<(), EvolutionError> {
373 if bytes.len() > crate::types::MAX_CHECKPOINT_BYTES {
374 return Err(EvolutionError::OversizedData {
375 context: "ast_mcts checkpoint restore".into(),
376 size: bytes.len(),
377 max: crate::types::MAX_CHECKPOINT_BYTES,
378 });
379 }
380 *self = serde_json::from_slice(bytes).map_err(EvolutionError::DeserializationFailed)?;
381 Ok(())
382 }
383
384 fn clone_box(&self) -> Box<dyn SearchAlgorithm> {
385 Box::new(self.clone())
386 }
387
388 fn population_snapshot(&self) -> Vec<Chromosome> {
389 vec![self.best.clone()]
390 }
391}
392
393#[must_use]
395pub fn all_rule_names() -> Vec<&'static str> {
396 RuleId::ALL.iter().map(|r| r.name()).collect()
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// Deterministic RNG so every test run sees the same MCTS expansions.
    fn make_rng() -> StdRng {
        StdRng::seed_from_u64(0x00C0_FFEE_BABE)
    }

    #[test]
    fn name_is_ast_mcts() {
        assert_eq!(AstMctsAlgorithm::new().name(), "ast_mcts");
    }

    #[test]
    fn initialize_with_empty_population_sets_empty_best_payload() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(vec![], &pool, &mut rng);
        assert!(alg.best_payload.is_empty());
    }

    #[test]
    fn initialize_with_sql_payload_captures_it() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        let seed = Chromosome::new(vec![("ast_mcts_payload".into(), "'a'='a'".into())]);
        alg.initialize(vec![seed], &pool, &mut rng);
        assert_eq!(alg.best_payload, "'a'='a'");
    }

    #[test]
    fn request_evaluations_returns_n_candidates() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        let seed = Chromosome::new(vec![("ast_mcts_payload".into(), "1=1".into())]);
        alg.initialize(vec![seed], &pool, &mut rng);
        let candidates = alg.request_evaluations(4, &mut rng);
        assert!(!candidates.is_empty(), "must return at least one candidate");
        assert!(candidates.len() <= 4);
    }

    #[test]
    fn request_evaluations_n_zero_returns_empty() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(vec![], &pool, &mut rng);
        let out = alg.request_evaluations(0, &mut rng);
        assert!(out.is_empty());
    }

    #[test]
    fn submit_evaluations_updates_best_on_pass() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        let seed = Chromosome::new(vec![("ast_mcts_payload".into(), "1=1".into())]);
        alg.initialize(vec![seed], &pool, &mut rng);

        let candidates = alg.request_evaluations(3, &mut rng);
        let first = candidates.into_iter().next().unwrap();
        let first_payload = first
            .chromosome
            .gene("ast_mcts_payload")
            .unwrap_or("")
            .to_string();

        let verdict = OracleVerdict::from_bool(true);
        alg.submit_evaluations(vec![(first.id, verdict)]);

        assert!(alg.bypass_found, "bypass_found must be set after a pass");
        assert_eq!(alg.best_payload, first_payload);
    }

    #[test]
    fn should_terminate_on_bypass() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(vec![], &pool, &mut rng);
        alg.bypass_found = true;
        let stats = SearchStats::new();
        let budget = Budget::default();
        assert!(alg.should_terminate(&stats, &budget));
    }

    #[test]
    fn checkpoint_roundtrip_preserves_state() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        let seed = Chromosome::new(vec![("ast_mcts_payload".into(), "'x'='x'".into())]);
        alg.initialize(vec![seed], &pool, &mut rng);
        alg.bypass_found = true;

        let bytes = alg.checkpoint().unwrap();
        let mut restored = AstMctsAlgorithm::new();
        restored.restore(&bytes).unwrap();

        assert_eq!(restored.best_payload, alg.best_payload);
        assert_eq!(restored.bypass_found, alg.bypass_found);
    }

    #[test]
    fn clone_box_produces_independent_instance() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        let seed = Chromosome::new(vec![("ast_mcts_payload".into(), "1=1".into())]);
        alg.initialize(vec![seed], &pool, &mut rng);

        let cloned = alg.clone_box();

        // Mutate the original after cloning; the clone must not observe it.
        set_gene(&mut alg.best, "ast_mcts_payload", "mutated");

        let cloned_best = cloned
            .best()
            .expect("clone must expose a best chromosome");
        assert_eq!(cloned_best.gene("ast_mcts_payload"), Some("1=1"));
        assert_eq!(alg.best.gene("ast_mcts_payload"), Some("mutated"));
    }

    #[test]
    fn all_rule_names_covers_all_16_rules() {
        let names = all_rule_names();
        assert_eq!(names.len(), 16, "all 16 RuleId variants must be named");
    }

    #[test]
    fn population_snapshot_returns_best() {
        let alg = AstMctsAlgorithm::new();
        let snap = alg.population_snapshot();
        assert_eq!(snap.len(), 1);
    }

    #[test]
    fn eval_counter_saturates_at_u64_max() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(
            vec![Chromosome::new(vec![(
                "ast_mcts_payload".into(),
                "1=1".into(),
            )])],
            &pool,
            &mut rng,
        );
        alg.eval_counter = u64::MAX;
        let _ = alg.request_evaluations(1, &mut rng);
        assert_eq!(
            alg.eval_counter,
            u64::MAX,
            "eval_counter must saturate at u64::MAX, not wrap to 0"
        );
    }

    #[test]
    fn generation_saturates_at_u32_max() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(vec![], &pool, &mut rng);
        alg.generation = u32::MAX;
        alg.submit_evaluations(vec![(0, OracleVerdict::from_bool(false))]);
        assert_eq!(
            alg.generation,
            u32::MAX,
            "generation must saturate at u32::MAX, not wrap to 0"
        );
    }

    #[test]
    fn rule_stats_nan_reward_does_not_poison_ucb1() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(
            vec![Chromosome::new(vec![(
                "ast_mcts_payload".into(),
                "1=1".into(),
            )])],
            &pool,
            &mut rng,
        );

        alg.rule_stats.insert(0, (10, f64::NAN));

        let candidates = alg.request_evaluations(2, &mut rng);
        if let Some(c) = candidates.into_iter().next() {
            alg.submit_evaluations(vec![(c.id, OracleVerdict::from_bool(true))]);
        }

        for (visits, total) in alg.rule_stats.values() {
            assert!(
                total.is_finite() || *visits == 0,
                "rule_stats total must be finite after NaN reset, got {total}"
            );
        }
    }

    #[test]
    fn rule_stats_inf_reward_does_not_poison_ucb1() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(
            vec![Chromosome::new(vec![(
                "ast_mcts_payload".into(),
                "1=1".into(),
            )])],
            &pool,
            &mut rng,
        );

        alg.rule_stats.insert(1, (5, f64::INFINITY));

        let candidates = alg.request_evaluations(2, &mut rng);
        if let Some(c) = candidates.into_iter().next() {
            alg.submit_evaluations(vec![(c.id, OracleVerdict::from_bool(false))]);
        }

        for (visits, total) in alg.rule_stats.values() {
            assert!(
                total.is_finite() || *visits == 0,
                "rule_stats total must be finite after Inf reset, got {total}"
            );
        }
    }

    #[test]
    fn rule_stats_nan_does_not_cross_contaminate_other_rules() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(
            vec![Chromosome::new(vec![(
                "ast_mcts_payload".into(),
                "1=1".into(),
            )])],
            &pool,
            &mut rng,
        );

        alg.rule_stats.insert(0, (3, 2.5));
        alg.rule_stats.insert(1, (7, f64::NAN));

        let candidates = alg.request_evaluations(1, &mut rng);
        if let Some(c) = candidates.into_iter().next() {
            alg.submit_evaluations(vec![(c.id, OracleVerdict::from_bool(true))]);
        }

        if let Some((_, total)) = alg.rule_stats.get(&0) {
            assert!(
                total.is_finite(),
                "healthy rule_stats entry must remain finite, got {total}"
            );
        }
    }
}