// ruqu_core/confidence.rs
//! Confidence bounds, statistical tests, and convergence utilities for
//! quantum measurement analysis.
//!
//! This module provides tools for reasoning about the statistical quality of
//! shot-based quantum simulation results, including confidence intervals for
//! binomial proportions, expectation values, shot budget estimation,
//! distribution distance metrics, goodness-of-fit tests, and convergence
//! monitoring.

use std::collections::HashMap;

// ---------------------------------------------------------------------------
// Core types
// ---------------------------------------------------------------------------

/// A confidence interval around a point estimate.
#[derive(Debug, Clone)]
pub struct ConfidenceInterval {
    /// Lower bound of the interval.
    pub lower: f64,
    /// Upper bound of the interval.
    pub upper: f64,
    /// Point estimate the interval is built around (e.g., the sample proportion).
    pub point_estimate: f64,
    /// Confidence level, e.g., 0.95 for a 95 % interval.
    pub confidence_level: f64,
    /// Human-readable label identifying the method that produced the bounds.
    pub method: &'static str,
}

/// Result of a chi-squared goodness-of-fit test.
#[derive(Debug, Clone)]
pub struct ChiSquaredResult {
    /// The chi-squared test statistic.
    pub statistic: f64,
    /// Degrees of freedom (number of categories minus one).
    pub degrees_of_freedom: usize,
    /// Approximate p-value of the statistic.
    pub p_value: f64,
    /// Whether the result is significant at the 0.05 level.
    pub significant: bool,
}

/// Tracks a running sequence of estimates and detects convergence.
///
/// Convergence is judged over the most recent `window_size` estimates: once
/// their spread (max - min) drops below a caller-supplied epsilon, the
/// sequence is considered converged (see `has_converged`).
pub struct ConvergenceMonitor {
    // All estimates recorded so far, in insertion order.
    estimates: Vec<f64>,
    // Number of trailing estimates examined when checking convergence.
    window_size: usize,
}

// ---------------------------------------------------------------------------
// Helpers: inverse normal CDF (z-score)
// ---------------------------------------------------------------------------

/// Approximate the z-score (inverse standard-normal CDF) for a given
/// two-sided confidence level using the rational approximation of
/// Abramowitz & Stegun (formula 26.2.23).
///
/// A two-sided interval at confidence level `c` corresponds to the upper
/// quantile at `p = (1 + c) / 2`; this function returns the z-value for that
/// quantile (e.g. ~1.96 for `c = 0.95`).
///
/// # Panics
///
/// Panics if `confidence` is not in the open interval (0, 1).
pub fn z_score(confidence: f64) -> f64 {
    assert!(
        confidence > 0.0 && confidence < 1.0,
        "confidence must be in (0, 1)"
    );

    // Upper-tail area 1 - (1 + c)/2: small and positive for levels near 1.
    let tail = 1.0 - (1.0 + confidence) / 2.0;

    // A&S 26.2.23: for tail area q, let t = sqrt(-2 ln q); then
    // z ~= t - (c0 + c1 t + c2 t^2) / (1 + d1 t + d2 t^2 + d3 t^3).
    let t = (-2.0_f64 * tail.ln()).sqrt();

    let c0 = 2.515517;
    let c1 = 0.802853;
    let c2 = 0.010328;
    let d1 = 1.432788;
    let d2 = 0.189269;
    let d3 = 0.001308;

    t - (c0 + c1 * t + c2 * t * t) / (1.0 + d1 * t + d2 * t * t + d3 * t * t * t)
}

87// ---------------------------------------------------------------------------
88// Wilson score interval
89// ---------------------------------------------------------------------------
90
91/// Compute the Wilson score confidence interval for a binomial proportion.
92///
93/// The Wilson interval is centred near the MLE but accounts for the discrete
94/// nature of the binomial and never produces bounds outside [0, 1].
95///
96/// # Arguments
97///
98/// * `successes` -- number of successes observed.
99/// * `trials`    -- total number of trials (must be > 0).
100/// * `confidence` -- desired confidence level in (0, 1).
101pub fn wilson_interval(successes: usize, trials: usize, confidence: f64) -> ConfidenceInterval {
102    assert!(trials > 0, "trials must be > 0");
103    assert!(
104        confidence > 0.0 && confidence < 1.0,
105        "confidence must be in (0, 1)"
106    );
107
108    let n = trials as f64;
109    let p_hat = successes as f64 / n;
110    let z = z_score(confidence);
111    let z2 = z * z;
112
113    let denom = 1.0 + z2 / n;
114    let centre = (p_hat + z2 / (2.0 * n)) / denom;
115    let half_width = z * (p_hat * (1.0 - p_hat) / n + z2 / (4.0 * n * n)).sqrt() / denom;
116
117    let lower = (centre - half_width).max(0.0);
118    let upper = (centre + half_width).min(1.0);
119
120    ConfidenceInterval {
121        lower,
122        upper,
123        point_estimate: p_hat,
124        confidence_level: confidence,
125        method: "wilson",
126    }
127}
128
129// ---------------------------------------------------------------------------
130// Clopper-Pearson exact interval
131// ---------------------------------------------------------------------------
132
133/// Compute the Clopper-Pearson (exact) confidence interval for a binomial
134/// proportion via bisection on the binomial CDF.
135///
136/// This interval is conservative -- it guarantees at least the nominal coverage
137/// probability, but may be wider than necessary.
138///
139/// # Arguments
140///
141/// * `successes` -- number of successes observed.
142/// * `trials`    -- total number of trials (must be > 0).
143/// * `confidence` -- desired confidence level in (0, 1).
144pub fn clopper_pearson(successes: usize, trials: usize, confidence: f64) -> ConfidenceInterval {
145    assert!(trials > 0, "trials must be > 0");
146    assert!(
147        confidence > 0.0 && confidence < 1.0,
148        "confidence must be in (0, 1)"
149    );
150
151    let alpha = 1.0 - confidence;
152    let n = trials;
153    let k = successes;
154    let p_hat = k as f64 / n as f64;
155
156    // Lower bound: find p such that P(X >= k | n, p) = alpha/2,
157    // equivalently P(X <= k-1 | n, p) = 1 - alpha/2.
158    let lower = if k == 0 {
159        0.0
160    } else {
161        bisect_binomial_cdf(n, k - 1, 1.0 - alpha / 2.0)
162    };
163
164    // Upper bound: find p such that P(X <= k | n, p) = alpha/2.
165    let upper = if k == n {
166        1.0
167    } else {
168        bisect_binomial_cdf(n, k, alpha / 2.0)
169    };
170
171    ConfidenceInterval {
172        lower,
173        upper,
174        point_estimate: p_hat,
175        confidence_level: confidence,
176        method: "clopper-pearson",
177    }
178}
179
180/// Use bisection to find `p` such that `binomial_cdf(n, k, p) = target`.
181///
182/// `binomial_cdf(n, k, p)` = sum_{i=0}^{k} C(n,i) p^i (1-p)^{n-i}.
183fn bisect_binomial_cdf(n: usize, k: usize, target: f64) -> f64 {
184    let mut lo = 0.0_f64;
185    let mut hi = 1.0_f64;
186
187    for _ in 0..200 {
188        let mid = (lo + hi) / 2.0;
189        let cdf = binomial_cdf(n, k, mid);
190        if cdf < target {
191            // CDF is too small; increasing p increases CDF, so move lo up.
192            // Actually: increasing p *decreases* P(X <= k) when k < n.
193            // Let's think carefully:
194            //   P(X <= k | p) is monotonically *decreasing* in p for k < n.
195            //   So if cdf < target we need to *decrease* p.
196            hi = mid;
197        } else {
198            lo = mid;
199        }
200
201        if (hi - lo) < 1e-15 {
202            break;
203        }
204    }
205    (lo + hi) / 2.0
206}
207
/// Evaluate the binomial CDF: P(X <= k) where X ~ Bin(n, p).
///
/// Terms are summed in log-space, building log C(n, i) incrementally, to
/// avoid overflow for large `n`. The O(k) summation is fast enough for
/// realistic shot counts and for the ~200 evaluations made per bisection.
fn binomial_cdf(n: usize, k: usize, p: f64) -> f64 {
    // Degenerate parameters first.
    if p <= 0.0 {
        return 1.0; // X is identically 0, so P(X <= k) = 1 for any k >= 0.
    }
    if p >= 1.0 {
        return if k >= n { 1.0 } else { 0.0 }; // X is identically n.
    }
    if k >= n {
        return 1.0; // The whole support is covered.
    }

    let ln_p = p.ln();
    let ln_1mp = (1.0 - p).ln();

    // i = 0 term: C(n,0) p^0 (1-p)^n, with log C(n,0) = 0.
    let mut log_binom = 0.0_f64;
    let mut cdf = (log_binom + ln_1mp * n as f64).exp();

    for i in 1..=k {
        // log C(n, i) = log C(n, i-1) + log(n - i + 1) - log(i)
        log_binom += ((n - i + 1) as f64).ln() - (i as f64).ln();
        cdf += (log_binom + ln_p * i as f64 + ln_1mp * (n - i) as f64).exp();
    }

    // Guard against tiny floating-point excursions outside [0, 1].
    cdf.clamp(0.0, 1.0)
}

247// ---------------------------------------------------------------------------
248// Expectation value confidence interval
249// ---------------------------------------------------------------------------
250
251/// Compute a confidence interval for the expectation value <Z> of a given
252/// qubit from shot counts.
253///
254/// For qubit `q`, the Z expectation value is `P(0) - P(1)` where P(0) is the
255/// fraction of shots where qubit `q` measured `false` and P(1) where it
256/// measured `true`.
257///
258/// The standard error is computed from the multinomial variance:
259///   Var(<Z>) = (1 - <Z>^2) / n
260///   SE       = sqrt(Var(<Z>) / n)  ... but more precisely, each shot produces
261///   a value +1 or -1 so Var = 1 - mean^2, and SE = sqrt(Var / n).
262///
263/// The returned interval is `<Z> +/- z * SE`.
264pub fn expectation_confidence(
265    counts: &HashMap<Vec<bool>, usize>,
266    qubit: u32,
267    confidence: f64,
268) -> ConfidenceInterval {
269    assert!(
270        confidence > 0.0 && confidence < 1.0,
271        "confidence must be in (0, 1)"
272    );
273
274    let mut n_zero: usize = 0;
275    let mut n_one: usize = 0;
276
277    for (bits, &count) in counts {
278        if let Some(&b) = bits.get(qubit as usize) {
279            if b {
280                n_one += count;
281            } else {
282                n_zero += count;
283            }
284        }
285    }
286
287    let total = (n_zero + n_one) as f64;
288    assert!(total > 0.0, "no shots found for the given qubit");
289
290    let p0 = n_zero as f64 / total;
291    let p1 = n_one as f64 / total;
292    let exp_z = p0 - p1; // <Z>
293
294    // Each shot yields +1 (qubit=0) or -1 (qubit=1).
295    // Variance of a single shot = E[X^2] - E[X]^2 = 1 - exp_z^2.
296    let var_single = 1.0 - exp_z * exp_z;
297    let se = (var_single / total).sqrt();
298
299    let z = z_score(confidence);
300    let lower = (exp_z - z * se).max(-1.0);
301    let upper = (exp_z + z * se).min(1.0);
302
303    ConfidenceInterval {
304        lower,
305        upper,
306        point_estimate: exp_z,
307        confidence_level: confidence,
308        method: "expectation-z-se",
309    }
310}
311
// ---------------------------------------------------------------------------
// Shot budget calculator
// ---------------------------------------------------------------------------

/// Compute the minimum number of shots required so that the additive error
/// of an empirical probability is at most `epsilon` with probability at
/// least `1 - delta`, using the Hoeffding bound:
///
///   N >= ln(2 / delta) / (2 * epsilon^2)
///
/// # Panics
///
/// Panics if `epsilon` or `delta` is not in (0, 1).
pub fn required_shots(epsilon: f64, delta: f64) -> usize {
    assert!(epsilon > 0.0 && epsilon < 1.0, "epsilon must be in (0, 1)");
    assert!(delta > 0.0 && delta < 1.0, "delta must be in (0, 1)");

    ((2.0_f64 / delta).ln() / (2.0 * epsilon * epsilon)).ceil() as usize
}

// ---------------------------------------------------------------------------
// Total variation distance
// ---------------------------------------------------------------------------

/// Compute the total variation distance between two empirical distributions
/// given as shot-count histograms.
///
/// TVD = 0.5 * sum_i |p_i - q_i| over all bitstrings present in either
/// histogram, where p_i and q_i are the normalised frequencies.
///
/// Two empty histograms have distance 0; an empty histogram contributes
/// probability 0 for every key.
pub fn total_variation_distance(
    p: &HashMap<Vec<bool>, usize>,
    q: &HashMap<Vec<bool>, usize>,
) -> f64 {
    // Normalised probability of `key` in `hist`, or 0 if `hist` is empty.
    fn mass(hist: &HashMap<Vec<bool>, usize>, total: f64, key: &Vec<bool>) -> f64 {
        if total > 0.0 {
            *hist.get(key).unwrap_or(&0) as f64 / total
        } else {
            0.0
        }
    }

    let total_p: f64 = p.values().sum::<usize>() as f64;
    let total_q: f64 = q.values().sum::<usize>() as f64;

    if total_p == 0.0 && total_q == 0.0 {
        return 0.0;
    }

    // Sum |p_i - q_i| over the union of keys: all of p's keys first, then
    // the keys that appear only in q.
    let mut abs_diff_sum = 0.0_f64;
    for key in p.keys() {
        abs_diff_sum += (mass(p, total_p, key) - mass(q, total_q, key)).abs();
    }
    for key in q.keys().filter(|key| !p.contains_key(*key)) {
        abs_diff_sum += (mass(p, total_p, key) - mass(q, total_q, key)).abs();
    }

    0.5 * abs_diff_sum
}

385// ---------------------------------------------------------------------------
386// Chi-squared test
387// ---------------------------------------------------------------------------
388
389/// Perform a chi-squared goodness-of-fit test comparing an observed
390/// distribution to an expected distribution.
391///
392/// The expected distribution is scaled to match the total number of observed
393/// counts. The p-value is approximated using the Wilson-Hilferty cube-root
394/// transformation of the chi-squared CDF.
395///
396/// # Panics
397///
398/// Panics if there are no categories or if the expected distribution has zero
399/// total counts.
400pub fn chi_squared_test(
401    observed: &HashMap<Vec<bool>, usize>,
402    expected: &HashMap<Vec<bool>, usize>,
403) -> ChiSquaredResult {
404    let total_observed: f64 = observed.values().sum::<usize>() as f64;
405    let total_expected: f64 = expected.values().sum::<usize>() as f64;
406
407    assert!(
408        total_expected > 0.0,
409        "expected distribution must have nonzero total"
410    );
411
412    // Collect all keys.
413    let mut all_keys: Vec<&Vec<bool>> = Vec::new();
414    for key in observed.keys() {
415        all_keys.push(key);
416    }
417    for key in expected.keys() {
418        if !observed.contains_key(key) {
419            all_keys.push(key);
420        }
421    }
422
423    let mut statistic = 0.0_f64;
424    let mut num_categories = 0_usize;
425
426    for key in &all_keys {
427        let o = *observed.get(*key).unwrap_or(&0) as f64;
428        // Scale expected counts to match observed total.
429        let e_raw = *expected.get(*key).unwrap_or(&0) as f64;
430        let e = e_raw * total_observed / total_expected;
431
432        if e > 0.0 {
433            statistic += (o - e) * (o - e) / e;
434            num_categories += 1;
435        }
436    }
437
438    let df = if num_categories > 1 {
439        num_categories - 1
440    } else {
441        1
442    };
443
444    let p_value = chi_squared_survival(statistic, df);
445
446    ChiSquaredResult {
447        statistic,
448        degrees_of_freedom: df,
449        p_value,
450        significant: p_value < 0.05,
451    }
452}
453
454/// Approximate the survival function (1 - CDF) of the chi-squared distribution
455/// using the Wilson-Hilferty normal approximation.
456///
457/// For chi-squared random variable X with k degrees of freedom:
458///   (X/k)^{1/3} is approximately normal with mean 1 - 2/(9k)
459///   and variance 2/(9k).
460///
461/// So P(X > x) approx P(Z > z) where
462///   z = ((x/k)^{1/3} - (1 - 2/(9k))) / sqrt(2/(9k))
463/// and P(Z > z) = 1 - Phi(z) = Phi(-z).
464fn chi_squared_survival(x: f64, df: usize) -> f64 {
465    if df == 0 {
466        return if x > 0.0 { 0.0 } else { 1.0 };
467    }
468
469    if x <= 0.0 {
470        return 1.0;
471    }
472
473    let k = df as f64;
474    let term = 2.0 / (9.0 * k);
475    let cube_root = (x / k).powf(1.0 / 3.0);
476    let z = (cube_root - (1.0 - term)) / term.sqrt();
477
478    // P(Z > z) = 1 - Phi(z) = Phi(-z)
479    normal_cdf(-z)
480}
481
/// Approximate the standard normal CDF `Phi(x)` using the Abramowitz &
/// Stegun polynomial approximation of the upper-tail probability
/// (formula 26.2.17), accurate to roughly 7.5e-8.
///
/// The polynomial yields the tail `Q(|x|) = 1 - Phi(|x|)` directly (no erf
/// evaluation is involved); symmetry `Phi(-x) = 1 - Phi(x)` extends the
/// result to negative arguments.
fn normal_cdf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x_abs = x.abs();

    let t = 1.0 / (1.0 + 0.2316419 * x_abs);
    let d = 0.3989422804014327; // 1/sqrt(2*pi), the standard normal density constant
    let p = d * (-x_abs * x_abs / 2.0).exp(); // pdf evaluated at |x|

    // Fifth-degree polynomial in t (Horner form); p * poly ~= Q(|x|).
    let poly = t
        * (0.319381530
            + t * (-0.356563782
                + t * (1.781477937 + t * (-1.821255978 + t * 1.330274429))));

    if sign > 0.0 {
        1.0 - p * poly // Phi(x) = 1 - Q(x) for x >= 0
    } else {
        p * poly // Phi(x) = Q(|x|) for x < 0, by symmetry
    }
}

506// ---------------------------------------------------------------------------
507// Convergence monitor
508// ---------------------------------------------------------------------------
509
510impl ConvergenceMonitor {
511    /// Create a new monitor with the given window size.
512    ///
513    /// The monitor considers the sequence converged when the last
514    /// `window_size` estimates all lie within `epsilon` of each other.
515    pub fn new(window_size: usize) -> Self {
516        assert!(window_size > 0, "window_size must be > 0");
517        Self {
518            estimates: Vec::new(),
519            window_size,
520        }
521    }
522
523    /// Record a new estimate.
524    pub fn add_estimate(&mut self, value: f64) {
525        self.estimates.push(value);
526    }
527
528    /// Check whether the last `window_size` estimates have converged: i.e.,
529    /// the maximum minus the minimum within the window is less than `epsilon`.
530    pub fn has_converged(&self, epsilon: f64) -> bool {
531        if self.estimates.len() < self.window_size {
532            return false;
533        }
534
535        let window = &self.estimates[self.estimates.len() - self.window_size..];
536        let min = window
537            .iter()
538            .copied()
539            .fold(f64::INFINITY, f64::min);
540        let max = window
541            .iter()
542            .copied()
543            .fold(f64::NEG_INFINITY, f64::max);
544
545        (max - min) < epsilon
546    }
547
548    /// Return the most recent estimate, or `None` if no estimates have been
549    /// added.
550    pub fn current_estimate(&self) -> Option<f64> {
551        self.estimates.last().copied()
552    }
553}
554
// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // z_score
    // -----------------------------------------------------------------------

    #[test]
    fn z_score_95() {
        let z = z_score(0.95);
        assert!((z - 1.96).abs() < 0.01, "z_score(0.95) = {z}, expected ~1.96");
    }

    #[test]
    fn z_score_99() {
        let z = z_score(0.99);
        assert!((z - 2.576).abs() < 0.02, "z_score(0.99) = {z}, expected ~2.576");
    }

    #[test]
    fn z_score_90() {
        let z = z_score(0.90);
        assert!((z - 1.645).abs() < 0.01, "z_score(0.90) = {z}, expected ~1.645");
    }

    // -----------------------------------------------------------------------
    // Wilson interval
    // -----------------------------------------------------------------------

    #[test]
    fn wilson_contains_true_proportion() {
        // 50 successes in 100 trials; the interval must straddle p = 0.5.
        let ci = wilson_interval(50, 100, 0.95);
        assert!(ci.lower < 0.5 && ci.upper > 0.5, "Wilson CI should contain 0.5: {ci:?}");
        assert_eq!(ci.method, "wilson");
        assert!((ci.point_estimate - 0.5).abs() < 1e-12);
    }

    #[test]
    fn wilson_asymmetric() {
        // A rare outcome (1/100) must still give an in-range interval.
        let ci = wilson_interval(1, 100, 0.95);
        assert!(ci.lower >= 0.0 && ci.upper <= 1.0);
        assert!(ci.lower < 0.01 && ci.upper > 0.01);
    }

    #[test]
    fn wilson_zero_successes() {
        let ci = wilson_interval(0, 100, 0.95);
        assert_eq!(ci.lower, 0.0);
        assert!(ci.upper > 0.0);
        assert!((ci.point_estimate - 0.0).abs() < 1e-12);
    }

    // -----------------------------------------------------------------------
    // Clopper-Pearson
    // -----------------------------------------------------------------------

    #[test]
    fn clopper_pearson_contains_true_proportion() {
        let ci = clopper_pearson(50, 100, 0.95);
        assert!(
            ci.lower < 0.5 && ci.upper > 0.5,
            "Clopper-Pearson CI should contain 0.5: {ci:?}"
        );
        assert_eq!(ci.method, "clopper-pearson");
    }

    #[test]
    fn clopper_pearson_is_conservative() {
        // The exact interval should never be narrower than Wilson's.
        let cp = clopper_pearson(50, 100, 0.95);
        let w = wilson_interval(50, 100, 0.95);

        let cp_width = cp.upper - cp.lower;
        let w_width = w.upper - w.lower;

        assert!(
            cp_width >= w_width - 1e-10,
            "Clopper-Pearson width ({cp_width}) should be >= Wilson width ({w_width})"
        );
    }

    #[test]
    fn clopper_pearson_edge_zero() {
        let ci = clopper_pearson(0, 100, 0.95);
        assert_eq!(ci.lower, 0.0);
        assert!(ci.upper > 0.0);
    }

    #[test]
    fn clopper_pearson_edge_all() {
        let ci = clopper_pearson(100, 100, 0.95);
        assert_eq!(ci.upper, 1.0);
        assert!(ci.lower < 1.0);
    }

    // -----------------------------------------------------------------------
    // Expectation value confidence
    // -----------------------------------------------------------------------

    #[test]
    fn expectation_all_zero() {
        // Every shot measures |0>, so <Z> = +1.
        let mut counts = HashMap::new();
        counts.insert(vec![false], 1000);
        let ci = expectation_confidence(&counts, 0, 0.95);
        assert!((ci.point_estimate - 1.0).abs() < 1e-12);
        assert!(ci.lower <= 1.0);
        assert!(ci.upper >= 1.0 - 1e-6);
    }

    #[test]
    fn expectation_all_one() {
        // Every shot measures |1>, so <Z> = -1.
        let mut counts = HashMap::new();
        counts.insert(vec![true], 1000);
        let ci = expectation_confidence(&counts, 0, 0.95);
        assert!((ci.point_estimate + 1.0).abs() < 1e-12);
    }

    #[test]
    fn expectation_balanced() {
        // A 50/50 split gives <Z> = 0 with an interval straddling zero.
        let mut counts = HashMap::new();
        counts.insert(vec![false], 500);
        counts.insert(vec![true], 500);
        let ci = expectation_confidence(&counts, 0, 0.95);
        assert!(
            ci.point_estimate.abs() < 1e-12,
            "expected 0.0, got {}",
            ci.point_estimate
        );
        assert!(ci.lower < 0.0 && ci.upper > 0.0);
    }

    #[test]
    fn expectation_multi_qubit() {
        // Two-qubit system: qubit 0 is always |0>, qubit 1 always |1>.
        let mut counts = HashMap::new();
        counts.insert(vec![false, true], 1000);
        let ci0 = expectation_confidence(&counts, 0, 0.95);
        let ci1 = expectation_confidence(&counts, 1, 0.95);
        assert!((ci0.point_estimate - 1.0).abs() < 1e-12);
        assert!((ci1.point_estimate + 1.0).abs() < 1e-12);
    }

    // -----------------------------------------------------------------------
    // Required shots
    // -----------------------------------------------------------------------

    #[test]
    fn required_shots_standard() {
        // ln(2/0.05) / (2 * 0.01^2) = ln(40) / 0.0002 ~= 18444.4 -> 18445.
        let n = required_shots(0.01, 0.05);
        assert!(
            (n as i64 - 18445).abs() <= 1,
            "required_shots(0.01, 0.05) = {n}, expected ~18445"
        );
    }

    #[test]
    fn required_shots_loose() {
        // ln(20) / 0.02 ~= 149.79 -> 150.
        let n = required_shots(0.1, 0.1);
        assert!((149..=151).contains(&n), "expected ~150, got {n}");
    }

    // -----------------------------------------------------------------------
    // Total variation distance
    // -----------------------------------------------------------------------

    #[test]
    fn tvd_identical() {
        let mut p = HashMap::new();
        for key in [
            vec![false, false],
            vec![false, true],
            vec![true, false],
            vec![true, true],
        ] {
            p.insert(key, 250);
        }

        let tvd = total_variation_distance(&p, &p);
        assert!(tvd.abs() < 1e-12, "TVD of identical distributions should be 0, got {tvd}");
    }

    #[test]
    fn tvd_completely_different() {
        let mut p = HashMap::new();
        p.insert(vec![false], 1000);

        let mut q = HashMap::new();
        q.insert(vec![true], 1000);

        let tvd = total_variation_distance(&p, &q);
        assert!(
            (tvd - 1.0).abs() < 1e-12,
            "TVD of completely different distributions should be 1.0, got {tvd}"
        );
    }

    #[test]
    fn tvd_partial_overlap() {
        let mut p = HashMap::new();
        p.insert(vec![false], 600);
        p.insert(vec![true], 400);

        let mut q = HashMap::new();
        q.insert(vec![false], 400);
        q.insert(vec![true], 600);

        // 0.5 * (|0.6 - 0.4| + |0.4 - 0.6|) = 0.2
        let tvd = total_variation_distance(&p, &q);
        assert!((tvd - 0.2).abs() < 1e-12, "expected 0.2, got {tvd}");
    }

    #[test]
    fn tvd_empty() {
        let p: HashMap<Vec<bool>, usize> = HashMap::new();
        let q: HashMap<Vec<bool>, usize> = HashMap::new();
        assert!(total_variation_distance(&p, &q).abs() < 1e-12);
    }

    // -----------------------------------------------------------------------
    // Chi-squared test
    // -----------------------------------------------------------------------

    #[test]
    fn chi_squared_matching() {
        // A distribution compared against itself matches perfectly.
        let mut obs = HashMap::new();
        obs.insert(vec![false, false], 250);
        obs.insert(vec![false, true], 250);
        obs.insert(vec![true, false], 250);
        obs.insert(vec![true, true], 250);

        let result = chi_squared_test(&obs, &obs);
        assert!(
            result.statistic < 1e-12,
            "statistic should be ~0 for identical distributions, got {}",
            result.statistic
        );
        assert!(
            result.p_value > 0.05,
            "p-value should be high for matching distributions, got {}",
            result.p_value
        );
        assert!(!result.significant);
    }

    #[test]
    fn chi_squared_very_different() {
        let mut obs = HashMap::new();
        obs.insert(vec![false], 1000);
        obs.insert(vec![true], 0);

        let mut exp = HashMap::new();
        exp.insert(vec![false], 500);
        exp.insert(vec![true], 500);

        let result = chi_squared_test(&obs, &exp);
        assert!(result.statistic > 100.0, "statistic should be large");
        assert!(result.p_value < 0.05, "p-value should be small: {}", result.p_value);
        assert!(result.significant);
    }

    #[test]
    fn chi_squared_degrees_of_freedom() {
        // Four categories => 3 degrees of freedom.
        let mut obs = HashMap::new();
        obs.insert(vec![false, false], 100);
        obs.insert(vec![false, true], 100);
        obs.insert(vec![true, false], 100);
        obs.insert(vec![true, true], 100);

        assert_eq!(chi_squared_test(&obs, &obs).degrees_of_freedom, 3);
    }

    // -----------------------------------------------------------------------
    // Convergence monitor
    // -----------------------------------------------------------------------

    #[test]
    fn convergence_detects_stable() {
        let mut monitor = ConvergenceMonitor::new(5);
        // A sequence that settles down near 0.5.
        for &v in &[0.5, 0.52, 0.49, 0.501, 0.499, 0.5001, 0.4999, 0.5002, 0.4998, 0.5001] {
            monitor.add_estimate(v);
        }
        assert!(
            monitor.has_converged(0.01),
            "should have converged: last 5 values are within 0.01"
        );
    }

    #[test]
    fn convergence_rejects_unstable() {
        let mut monitor = ConvergenceMonitor::new(5);
        for &v in &[0.1, 0.9, 0.1, 0.9, 0.1, 0.9, 0.1, 0.9, 0.1, 0.9] {
            monitor.add_estimate(v);
        }
        assert!(
            !monitor.has_converged(0.01),
            "should NOT have converged: values oscillate widely"
        );
    }

    #[test]
    fn convergence_insufficient_data() {
        let mut monitor = ConvergenceMonitor::new(10);
        monitor.add_estimate(1.0);
        monitor.add_estimate(1.0);
        assert!(!monitor.has_converged(0.1), "not enough data for window_size=10");
    }

    #[test]
    fn convergence_current_estimate() {
        let mut monitor = ConvergenceMonitor::new(3);
        assert_eq!(monitor.current_estimate(), None);
        monitor.add_estimate(42.0);
        assert_eq!(monitor.current_estimate(), Some(42.0));
        monitor.add_estimate(43.0);
        assert_eq!(monitor.current_estimate(), Some(43.0));
    }

    // -----------------------------------------------------------------------
    // Binomial CDF helper
    // -----------------------------------------------------------------------

    #[test]
    fn binomial_cdf_edge_cases() {
        // P(X <= n) is always 1.
        assert!((binomial_cdf(10, 10, 0.5) - 1.0).abs() < 1e-12);

        // P(X <= 0 | 10, 0.5) = 0.5^10 ~= 0.000977.
        assert!((binomial_cdf(10, 0, 0.5) - 0.0009765625).abs() < 1e-8);
    }

    // -----------------------------------------------------------------------
    // Normal CDF helper
    // -----------------------------------------------------------------------

    #[test]
    fn normal_cdf_values() {
        assert!((normal_cdf(0.0) - 0.5).abs() < 1e-6); // Phi(0) = 0.5
        assert!((normal_cdf(1.96) - 0.975).abs() < 0.002); // Phi(1.96) ~ 0.975
        assert!((normal_cdf(-1.96) - 0.025).abs() < 0.002); // Phi(-1.96) ~ 0.025
    }
}