Skip to main content

wafrift_evolution/search/
ast_mcts_algorithm.rs

//! AST-MCTS [`SearchAlgorithm`] adapter.
//!
//! Bridges [`crate::ast_mcts::mcts_search`] — which operates on raw SQL
//! payload strings — into the [`SearchAlgorithm`] trait so
//! [`crate::evolution::EvolutionEngine`] can select it alongside
//! MAP-Elites, Novelty, UCB1, etc.
//!
//! # Design
//!
//! When `request_evaluations` finds no pending candidates, it runs MCTS
//! internally against an inline oracle that records every payload it is
//! asked about and reports each one as blocked. MCTS thus spends its internal
//! budget enumerating rewrites. The distinct AST-rewritten payloads it
//! produced (one per rule × position arm it tried) become `EvalCandidate`s
//! for the engine's external oracle, which supplies the real verdict.
//! `submit_evaluations` feeds those verdicts back by updating the best-known
//! payload that seeds the next MCTS round.
//!
//! The chromosome's `ast_mcts_payload` gene carries the rewritten SQL
//! fragment. Other genes are inherited from the current best chromosome
//! (initially the population seed), so the engine's gene-success stats
//! continue working across mutator modes.
//!
//! # Determinism
//!
//! The only randomness this adapter draws itself comes from the engine's
//! `StdRng`, passed in as the `&mut StdRng` argument of
//! `request_evaluations`. It seeds the warm-start chromosomes and the inline
//! oracle's jitter state; `mcts_search` is called without an RNG. The same
//! seed (and the same external verdicts) → same evaluation sequence → same
//! payload distribution. Verified by `tests/ast_mcts_wiring.rs`.
27
28use crate::ast_mcts::{AstMctsOracle, MctsResult, RuleId, mcts_search};
29use crate::evolution::{Chromosome, GenePool, population::random_chromosome};
30use crate::lineage::Lineage;
31use crate::search::{EvalCandidate, SearchAlgorithm, fitness_cmp};
32use crate::types::{Budget, EvolutionError, OracleVerdict, SearchStats};
33use rand::RngCore;
34use rand::rngs::StdRng;
35use serde::{Deserialize, Serialize};
36use std::collections::HashMap;
37
/// Oracle queries spent inside one MCTS round when no explicit budget is
/// passed to [`AstMctsAlgorithm::with_config`].
pub const DEFAULT_MCTS_BUDGET: u64 = 64;

/// UCB1 exploration constant used by default: √2 (≈ 1.414), the value
/// recommended by the AdvSQLi paper.
pub const DEFAULT_UCB1_C: f64 = std::f64::consts::SQRT_2;
43
44/// Inline oracle used during the MCTS rollout phase.
45///
46/// Records each candidate it is asked to evaluate so the caller can
47/// later surface them as `EvalCandidate`s to the external oracle.
48/// Returns `true` (blocked) by default until an external bypass signal
49/// has been received, then toggles probabilistically based on the
50/// per-rule UCB1 statistics from the previous round.
51struct InlineOracle<'a> {
52    /// Payloads generated during this rollout, in evaluation order.
53    candidates: &'a mut Vec<String>,
54    /// Whether a bypass was seen in the previous round (seed signal).
55    prior_bypass: bool,
56    /// Pseudo-random jitter source so repeated arms don't collapse.
57    jitter: u64,
58}
59
60impl<'a> AstMctsOracle for InlineOracle<'a> {
61    fn eval(&mut self, candidate: &str) -> bool {
62        self.candidates.push(candidate.to_string());
63        self.jitter = self
64            .jitter
65            .wrapping_mul(6_364_136_223_846_793_005)
66            .wrapping_add(1_442_695_040_888_963_407);
67        if self.prior_bypass {
68            // If we've seen a bypass before, treat even new arms as "blocked"
69            // so MCTS keeps exploring (we're in ablation mode, not live-fire).
70            true
71        } else {
72            // With no prior signal, everything is treated as blocked — the
73            // external oracle provides the true signal after the batch.
74            true
75        }
76    }
77}
78
79/// A `SearchAlgorithm` that uses AST-MCTS over SQL/XSS payload fragments.
80///
81/// On each `request_evaluations` call, the algorithm:
82/// 1. Picks the current best payload (from the seed or last bypass).
83/// 2. Runs `mcts_search` with an inline oracle to enumerate candidate rewrites.
84/// 3. Wraps each rewrite as a `Chromosome` with an `ast_mcts_payload` gene.
85/// 4. Returns up to `n` candidates for external oracle verification.
86///
87/// `submit_evaluations` updates the best-known payload whenever a bypass
88/// (passed == true) is received.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct AstMctsAlgorithm {
91    /// Current best chromosome (highest fitness seen so far).
92    best: Chromosome,
93    /// Gene pool for generating seed chromosomes when no population is provided.
94    gene_pool: GenePool,
95    /// Generation counter.
96    generation: u32,
97    /// Monotonic evaluation ID counter.
98    eval_counter: u64,
99    /// Payload fragment associated with the best chromosome.
100    best_payload: String,
101    /// Whether the best payload has been confirmed as a bypass.
102    bypass_found: bool,
103    /// Per-RuleId UCB1 statistics carried across rounds: (visits, total_reward).
104    #[serde(default)]
105    rule_stats: HashMap<u8, (u64, f64)>,
106    /// In-flight map: eval_id → chromosome.
107    #[serde(skip)]
108    in_flight: HashMap<u64, Chromosome>,
109    /// Budget of oracle queries per MCTS round.
110    mcts_budget: u64,
111    /// UCB1 exploration constant.
112    ucb1_c: f64,
113    /// Pending candidates: pre-generated payloads waiting to be dispatched.
114    #[serde(skip)]
115    pending: Vec<(u64, Chromosome)>,
116}
117
118impl AstMctsAlgorithm {
119    /// Create a new instance with default MCTS budget and UCB1 constant.
120    #[must_use]
121    pub fn new() -> Self {
122        Self::with_config(DEFAULT_MCTS_BUDGET, DEFAULT_UCB1_C)
123    }
124
125    /// Create with explicit MCTS budget and UCB1 constant.
126    ///
127    /// - `mcts_budget`: oracle queries spent per round inside MCTS.
128    /// - `ucb1_c`: exploration constant; `sqrt(2)` is the AdvSQLi default.
129    #[must_use]
130    pub fn with_config(mcts_budget: u64, ucb1_c: f64) -> Self {
131        Self {
132            best: Chromosome::new(vec![("ast_mcts_payload".into(), String::new())]),
133            gene_pool: GenePool::default_wafrift(),
134            generation: 0,
135            eval_counter: 0,
136            best_payload: String::new(),
137            bypass_found: false,
138            rule_stats: HashMap::new(),
139            in_flight: HashMap::new(),
140            mcts_budget,
141            ucb1_c,
142            pending: Vec::new(),
143        }
144    }
145
146    /// Extract the SQL payload from a chromosome's `ast_mcts_payload` gene,
147    /// falling back to its `payload` gene, and ultimately to an empty string.
148    fn payload_from_chromosome(c: &Chromosome) -> &str {
149        c.gene("ast_mcts_payload")
150            .or_else(|| c.gene("payload"))
151            .unwrap_or("")
152    }
153
154    /// Run one MCTS round and populate `self.pending` with new candidates.
155    ///
156    /// If the current best payload is empty (no seed yet), emits a single
157    /// baseline chromosome with an empty payload so the engine can warm-start.
158    fn replenish(&mut self, n: usize, rng: &mut StdRng) {
159        if self.best_payload.is_empty() {
160            // No payload yet — emit baseline chromosomes drawn from gene pool.
161            for _ in 0..n {
162                self.eval_counter = self.eval_counter.saturating_add(1);
163                let mut c = random_chromosome(&self.gene_pool, rng);
164                c.genes.push(("ast_mcts_payload".into(), String::new()));
165                c.lineage = Lineage::genesis(self.generation);
166                self.pending.push((self.eval_counter, c));
167            }
168            return;
169        }
170
171        // Run MCTS using an inline oracle to enumerate candidate rewrites.
172        let jitter: u64 = rng.next_u64();
173        let mut generated: Vec<String> = Vec::new();
174        let mut inline = InlineOracle {
175            candidates: &mut generated,
176            prior_bypass: self.bypass_found,
177            jitter,
178        };
179
180        let result: Option<MctsResult> = mcts_search(
181            &self.best_payload,
182            self.mcts_budget,
183            self.ucb1_c,
184            &mut inline,
185        );
186
187        // Absorb arm stats for cross-round learning.
188        if let Some(ref r) = result {
189            for &(action, visits, mean_reward) in &r.arm_stats {
190                let entry = self.rule_stats.entry(action.rule.0).or_insert((0, 0.0));
191                entry.0 = entry.0.saturating_add(visits);
192                // Guard against non-finite mean_reward to prevent Inf/NaN
193                // accumulation in the running total. visits is u64 cast to f64;
194                // above 2^53 the cast loses precision but cannot produce NaN/Inf.
195                let addend = if mean_reward.is_finite() {
196                    mean_reward * (visits as f64)
197                } else {
198                    0.0
199                };
200                entry.1 = if entry.1.is_finite() {
201                    entry.1 + addend
202                } else {
203                    // The running total somehow became non-finite (adversarial
204                    // oracle, upstream bug). Reset to the current observation
205                    // rather than propagating the poison.
206                    addend
207                };
208            }
209        }
210
211        // Deduplicate generated payloads; prefer best_payload candidates first.
212        let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
213        // Always include the MCTS-best payload if it produced one.
214        if let Some(ref r) = result
215            && !r.best_payload.is_empty()
216            && seen.insert(r.best_payload.clone())
217        {
218            self.eval_counter = self.eval_counter.saturating_add(1);
219            let mut c = self.best.clone();
220            let payload = r.best_payload.clone();
221            set_gene(&mut c, "ast_mcts_payload", &payload);
222            c.lineage = Lineage::mutation(
223                &self.best,
224                vec![crate::lineage::MutationOp {
225                    gene_name: "ast_mcts_payload".into(),
226                    from: self.best_payload.clone(),
227                    to: payload.clone(),
228                    operator: "ast_mcts:best_payload".into(),
229                }],
230                self.generation,
231            );
232            self.pending.push((self.eval_counter, c));
233        }
234
235        // Then include other generated candidates up to n.
236        for payload in generated {
237            if self.pending.len() >= n {
238                break;
239            }
240            if payload.is_empty() || !seen.insert(payload.clone()) {
241                continue;
242            }
243            self.eval_counter = self.eval_counter.saturating_add(1);
244            let mut c = self.best.clone();
245            set_gene(&mut c, "ast_mcts_payload", &payload);
246            c.lineage = Lineage::mutation(
247                &self.best,
248                vec![crate::lineage::MutationOp {
249                    gene_name: "ast_mcts_payload".into(),
250                    from: self.best_payload.clone(),
251                    to: payload.clone(),
252                    operator: "ast_mcts:inline_candidate".into(),
253                }],
254                self.generation,
255            );
256            self.pending.push((self.eval_counter, c));
257        }
258
259        // If MCTS produced nothing useful, emit the original payload as a fallback.
260        if self.pending.is_empty() {
261            self.eval_counter = self.eval_counter.saturating_add(1);
262            let mut c = self.best.clone();
263            set_gene(&mut c, "ast_mcts_payload", &self.best_payload);
264            c.lineage = Lineage::genesis(self.generation);
265            self.pending.push((self.eval_counter, c));
266        }
267    }
268}
269
270impl Default for AstMctsAlgorithm {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276/// Update or insert a gene in a chromosome's gene list.
277fn set_gene(c: &mut Chromosome, name: &str, value: &str) {
278    if let Some(entry) = c.genes.iter_mut().find(|(k, _)| k == name) {
279        entry.1 = value.to_string();
280    } else {
281        c.genes.push((name.to_string(), value.to_string()));
282    }
283}
284
285impl SearchAlgorithm for AstMctsAlgorithm {
286    fn name(&self) -> &'static str {
287        "ast_mcts"
288    }
289
290    fn initialize(&mut self, population: Vec<Chromosome>, gene_pool: &GenePool, _rng: &mut StdRng) {
291        self.gene_pool = gene_pool.clone();
292        self.generation = 0;
293        self.eval_counter = 0;
294        self.bypass_found = false;
295        self.pending.clear();
296        self.in_flight.clear();
297
298        // Pick the highest-fitness seed from the provided population.
299        if let Some(seed) = population
300            .into_iter()
301            .max_by(|a, b| fitness_cmp(a.fitness, b.fitness))
302        {
303            let payload = Self::payload_from_chromosome(&seed).to_string();
304            self.best_payload = payload;
305            self.best = seed;
306        }
307        // Ensure the best chromosome always has the ast_mcts_payload gene.
308        set_gene(&mut self.best, "ast_mcts_payload", &self.best_payload);
309    }
310
311    fn request_evaluations(&mut self, n: usize, rng: &mut StdRng) -> Vec<EvalCandidate> {
312        if n == 0 {
313            return Vec::new();
314        }
315        // Fill pending if empty.
316        if self.pending.is_empty() {
317            self.replenish(n, rng);
318        }
319
320        // Drain up to n from pending.
321        let drain_count = n.min(self.pending.len());
322        let batch: Vec<(u64, Chromosome)> = self.pending.drain(..drain_count).collect();
323
324        let mut out = Vec::with_capacity(batch.len());
325        for (id, chromosome) in batch {
326            self.in_flight.insert(id, chromosome.clone());
327            out.push(EvalCandidate { id, chromosome });
328        }
329        out
330    }
331
332    fn submit_evaluations(&mut self, results: Vec<(u64, OracleVerdict)>) {
333        for (id, verdict) in results {
334            let Some(mut chromosome) = self.in_flight.remove(&id) else {
335                continue;
336            };
337            chromosome.record_verdict(&verdict);
338
339            // Update best on improvement.
340            if verdict.passed || chromosome.fitness > self.best.fitness {
341                if verdict.passed && !self.bypass_found {
342                    self.bypass_found = true;
343                }
344                let new_payload = chromosome
345                    .gene("ast_mcts_payload")
346                    .unwrap_or("")
347                    .to_string();
348                if !new_payload.is_empty() {
349                    self.best_payload = new_payload;
350                }
351                self.best = chromosome;
352            }
353        }
354        self.generation = self.generation.saturating_add(1);
355    }
356
357    fn should_terminate(&self, stats: &SearchStats, budget: &Budget) -> bool {
358        self.bypass_found
359            || stats.evaluations >= budget.max_requests
360            || stats.generation >= budget.max_generations
361            || stats.stagnation_counter >= budget.stagnation_limit
362    }
363
364    fn best(&self) -> Option<&Chromosome> {
365        Some(&self.best)
366    }
367
368    fn checkpoint(&self) -> Result<Vec<u8>, EvolutionError> {
369        serde_json::to_vec(self).map_err(EvolutionError::SerializationFailed)
370    }
371
372    fn restore(&mut self, bytes: &[u8]) -> Result<(), EvolutionError> {
373        if bytes.len() > crate::types::MAX_CHECKPOINT_BYTES {
374            return Err(EvolutionError::OversizedData {
375                context: "ast_mcts checkpoint restore".into(),
376                size: bytes.len(),
377                max: crate::types::MAX_CHECKPOINT_BYTES,
378            });
379        }
380        *self = serde_json::from_slice(bytes).map_err(EvolutionError::DeserializationFailed)?;
381        Ok(())
382    }
383
384    fn clone_box(&self) -> Box<dyn SearchAlgorithm> {
385        Box::new(self.clone())
386    }
387
388    fn population_snapshot(&self) -> Vec<Chromosome> {
389        vec![self.best.clone()]
390    }
391}
392
393/// Convenience: the rule names that AST-MCTS uses, suitable for reporting.
394#[must_use]
395pub fn all_rule_names() -> Vec<&'static str> {
396    RuleId::ALL.iter().map(|r| r.name()).collect()
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// Fixed-seed RNG so every test is reproducible.
    fn make_rng() -> StdRng {
        StdRng::seed_from_u64(0x00C0_FFEE_BABE)
    }

    #[test]
    fn name_is_ast_mcts() {
        assert_eq!(AstMctsAlgorithm::new().name(), "ast_mcts");
    }

    #[test]
    fn initialize_with_empty_population_sets_empty_best_payload() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(vec![], &pool, &mut rng);
        assert!(alg.best_payload.is_empty());
    }

    #[test]
    fn initialize_with_sql_payload_captures_it() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        let seed = Chromosome::new(vec![("ast_mcts_payload".into(), "'a'='a'".into())]);
        alg.initialize(vec![seed], &pool, &mut rng);
        assert_eq!(alg.best_payload, "'a'='a'");
    }

    #[test]
    fn request_evaluations_returns_n_candidates() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        let seed = Chromosome::new(vec![("ast_mcts_payload".into(), "1=1".into())]);
        alg.initialize(vec![seed], &pool, &mut rng);
        let candidates = alg.request_evaluations(4, &mut rng);
        // May return fewer than 4 if MCTS produces fewer distinct rewrites.
        assert!(!candidates.is_empty(), "must return at least one candidate");
        assert!(candidates.len() <= 4);
    }

    #[test]
    fn request_evaluations_n_zero_returns_empty() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(vec![], &pool, &mut rng);
        let out = alg.request_evaluations(0, &mut rng);
        assert!(out.is_empty());
    }

    #[test]
    fn submit_evaluations_updates_best_on_pass() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        let seed = Chromosome::new(vec![("ast_mcts_payload".into(), "1=1".into())]);
        alg.initialize(vec![seed], &pool, &mut rng);

        let candidates = alg.request_evaluations(3, &mut rng);
        let first = candidates.into_iter().next().unwrap();
        let first_payload = first
            .chromosome
            .gene("ast_mcts_payload")
            .unwrap_or("")
            .to_string();

        // Simulate a bypass verdict.
        let verdict = OracleVerdict::from_bool(true);
        alg.submit_evaluations(vec![(first.id, verdict)]);

        assert!(alg.bypass_found, "bypass_found must be set after a pass");
        assert_eq!(alg.best_payload, first_payload);
    }

    #[test]
    fn should_terminate_on_bypass() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(vec![], &pool, &mut rng);
        alg.bypass_found = true;
        let stats = SearchStats::new();
        let budget = Budget::default();
        assert!(alg.should_terminate(&stats, &budget));
    }

    #[test]
    fn checkpoint_roundtrip_preserves_state() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        let seed = Chromosome::new(vec![("ast_mcts_payload".into(), "'x'='x'".into())]);
        alg.initialize(vec![seed], &pool, &mut rng);
        alg.bypass_found = true;

        let bytes = alg.checkpoint().unwrap();
        let mut restored = AstMctsAlgorithm::new();
        restored.restore(&bytes).unwrap();

        assert_eq!(restored.best_payload, alg.best_payload);
        assert_eq!(restored.bypass_found, alg.bypass_found);
    }

    #[test]
    fn clone_box_produces_independent_instance() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        let seed = Chromosome::new(vec![("ast_mcts_payload".into(), "1=1".into())]);
        alg.initialize(vec![seed], &pool, &mut rng);

        let cloned = alg.clone_box();
        // Mutate the original after cloning — the clone must not observe it.
        set_gene(&mut alg.best, "ast_mcts_payload", "2=2");
        alg.best_payload = "2=2".into();
        alg.bypass_found = true;

        assert_eq!(
            cloned.best().unwrap().gene("ast_mcts_payload"),
            Some("1=1"),
            "clone must keep the pre-mutation best payload"
        );
        assert_eq!(alg.best().unwrap().gene("ast_mcts_payload"), Some("2=2"));
    }

    #[test]
    fn all_rule_names_covers_all_16_rules() {
        let names = all_rule_names();
        assert_eq!(names.len(), 16, "all 16 RuleId variants must be named");
    }

    #[test]
    fn population_snapshot_returns_best() {
        let alg = AstMctsAlgorithm::new();
        let snap = alg.population_snapshot();
        assert_eq!(snap.len(), 1);
    }

    // ── Saturating-arithmetic + NaN/Inf regression tests ─────────────────────

    /// `eval_counter` must saturate at `u64::MAX` rather than wrapping to 0.
    /// A wrap-around would reuse previously-issued IDs, causing the engine's
    /// `in_flight` map to collide and silently drop evaluations.
    #[test]
    fn eval_counter_saturates_at_u64_max() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(
            vec![Chromosome::new(vec![(
                "ast_mcts_payload".into(),
                "1=1".into(),
            )])],
            &pool,
            &mut rng,
        );
        alg.eval_counter = u64::MAX;
        // request_evaluations calls saturating_add — counter must stay at MAX.
        let _ = alg.request_evaluations(1, &mut rng);
        assert_eq!(
            alg.eval_counter,
            u64::MAX,
            "eval_counter must saturate at u64::MAX, not wrap to 0"
        );
    }

    /// `generation` must saturate at `u32::MAX` rather than wrapping.
    #[test]
    fn generation_saturates_at_u32_max() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(vec![], &pool, &mut rng);
        alg.generation = u32::MAX;
        // submit_evaluations increments generation even for unknown ids.
        alg.submit_evaluations(vec![(0, OracleVerdict::from_bool(false))]);
        assert_eq!(
            alg.generation,
            u32::MAX,
            "generation must saturate at u32::MAX, not wrap to 0"
        );
    }

    /// A NaN running total in `rule_stats` (e.g. left behind by a buggy
    /// oracle) must not persist. The next MCTS round that visits the rule
    /// must reset it to a finite value.
    #[test]
    fn rule_stats_nan_reward_does_not_poison_ucb1() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(
            vec![Chromosome::new(vec![(
                "ast_mcts_payload".into(),
                "1=1".into(),
            )])],
            &pool,
            &mut rng,
        );

        // Manually inject NaN into rule_stats (simulating a buggy oracle).
        alg.rule_stats.insert(0, (10, f64::NAN));

        // Requesting a batch runs an MCTS round, and `replenish` folds its arm
        // stats into `rule_stats`, resetting non-finite totals for the rules
        // it visits. Submitting the verdict afterwards does not touch
        // `rule_stats`. NOTE: relies on this MCTS round visiting rule 0.
        let candidates = alg.request_evaluations(2, &mut rng);
        if let Some(c) = candidates.into_iter().next() {
            alg.submit_evaluations(vec![(c.id, OracleVerdict::from_bool(true))]);
        }

        for (visits, total) in alg.rule_stats.values() {
            assert!(
                total.is_finite() || *visits == 0,
                "rule_stats total must be finite after NaN reset, got {total}"
            );
        }
    }

    /// `+Inf` in the running total must also be cleared (same guard).
    #[test]
    fn rule_stats_inf_reward_does_not_poison_ucb1() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(
            vec![Chromosome::new(vec![(
                "ast_mcts_payload".into(),
                "1=1".into(),
            )])],
            &pool,
            &mut rng,
        );

        alg.rule_stats.insert(1, (5, f64::INFINITY));

        // The reset happens inside request_evaluations (MCTS round).
        // NOTE: relies on this round visiting rule 1.
        let candidates = alg.request_evaluations(2, &mut rng);
        if let Some(c) = candidates.into_iter().next() {
            alg.submit_evaluations(vec![(c.id, OracleVerdict::from_bool(false))]);
        }

        for (visits, total) in alg.rule_stats.values() {
            assert!(
                total.is_finite() || *visits == 0,
                "rule_stats total must be finite after Inf reset, got {total}"
            );
        }
    }

    /// A NaN total in one rule's entry must not affect a *different* rule's
    /// entry — the guard is per-entry.
    #[test]
    fn rule_stats_nan_does_not_cross_contaminate_other_rules() {
        let mut alg = AstMctsAlgorithm::new();
        let pool = GenePool::default_wafrift();
        let mut rng = make_rng();
        alg.initialize(
            vec![Chromosome::new(vec![(
                "ast_mcts_payload".into(),
                "1=1".into(),
            )])],
            &pool,
            &mut rng,
        );

        // Rule 0: healthy entry; rule 1: NaN-poisoned.
        alg.rule_stats.insert(0, (3, 2.5));
        alg.rule_stats.insert(1, (7, f64::NAN));

        // Run one MCTS round (which updates stats), then submit its verdict.
        let candidates = alg.request_evaluations(1, &mut rng);
        if let Some(c) = candidates.into_iter().next() {
            alg.submit_evaluations(vec![(c.id, OracleVerdict::from_bool(true))]);
        }

        // Rule 0's total must still be finite (it may have grown this round).
        if let Some((_, total)) = alg.rule_stats.get(&0) {
            assert!(
                total.is_finite(),
                "healthy rule_stats entry must remain finite, got {total}"
            );
        }
    }
}