Skip to main content

tacet_core/adaptive/
state.rs

1//! State management for adaptive sampling loop (no_std compatible).
2//!
3//! This module defines the state maintained during the adaptive sampling process,
4//! including sample storage, posterior tracking, and KL divergence history.
5//!
6//! Time tracking is handled by the caller - this module is stateless with respect
7//! to wall-clock time, making it suitable for no_std environments like SGX enclaves.
8
9use alloc::collections::VecDeque;
10use alloc::vec::Vec;
11
12use crate::statistics::{OnlineStats, StatsSnapshot};
13
14use super::{kl_divergence_gaussian, CalibrationSnapshot, Posterior};
15
16/// State maintained during adaptive sampling loop.
17///
18/// This struct accumulates timing samples and tracks the evolution of the
19/// posterior distribution across batches, enabling quality gate checks like
20/// KL divergence monitoring.
21///
22/// Also maintains online statistics (mean, variance, lag-1 autocorrelation)
23/// for condition drift detection between calibration and post-test phases.
24///
25/// # Time Tracking
26///
27/// This struct does NOT track elapsed time internally. The caller must provide
28/// `elapsed_secs` to functions that need it (e.g., quality gate checks). This
29/// allows use in no_std environments where `std::time::Instant` is unavailable.
30pub struct AdaptiveState {
31    /// Baseline class timing samples (in cycles/ticks/native units).
32    pub baseline_samples: Vec<u64>,
33
34    /// Sample class timing samples (in cycles/ticks/native units).
35    pub sample_samples: Vec<u64>,
36
37    /// Previous posterior for KL divergence tracking.
38    /// None until we have at least one posterior computed.
39    pub previous_posterior: Option<Posterior>,
40
41    /// Recent KL divergences (last 5 batches) for learning rate monitoring.
42    /// If sum of recent KL < 0.001, learning has stalled.
43    pub recent_kl_divergences: VecDeque<f64>,
44
45    /// Number of batches collected so far.
46    pub batch_count: usize,
47
48    /// Online statistics tracker for baseline class (for drift detection).
49    baseline_stats: OnlineStats,
50
51    /// Online statistics tracker for sample class (for drift detection).
52    sample_stats: OnlineStats,
53
54    /// Conversion factor from native units to nanoseconds.
55    /// Set when first batch is added with ns_per_tick context.
56    ns_per_tick: Option<f64>,
57}
58
59impl AdaptiveState {
60    /// Create a new empty adaptive state.
61    pub fn new() -> Self {
62        Self {
63            baseline_samples: Vec::new(),
64            sample_samples: Vec::new(),
65            previous_posterior: None,
66            recent_kl_divergences: VecDeque::with_capacity(5),
67            batch_count: 0,
68            baseline_stats: OnlineStats::new(),
69            sample_stats: OnlineStats::new(),
70            ns_per_tick: None,
71        }
72    }
73
74    /// Create a new adaptive state with pre-allocated capacity.
75    pub fn with_capacity(expected_samples: usize) -> Self {
76        Self {
77            baseline_samples: Vec::with_capacity(expected_samples),
78            sample_samples: Vec::with_capacity(expected_samples),
79            previous_posterior: None,
80            recent_kl_divergences: VecDeque::with_capacity(5),
81            batch_count: 0,
82            baseline_stats: OnlineStats::new(),
83            sample_stats: OnlineStats::new(),
84            ns_per_tick: None,
85        }
86    }
87
88    /// Get the total number of samples per class.
89    pub fn n_total(&self) -> usize {
90        self.baseline_samples.len()
91    }
92
93    /// Add a batch of samples to the state.
94    ///
95    /// Both baseline and sample vectors should have the same length.
96    /// Note: This method does not track online statistics since ns_per_tick is not known.
97    /// Use `add_batch_with_conversion` if you need drift detection.
98    pub fn add_batch(&mut self, baseline: Vec<u64>, sample: Vec<u64>) {
99        debug_assert_eq!(
100            baseline.len(),
101            sample.len(),
102            "Baseline and sample batch sizes must match"
103        );
104        self.baseline_samples.extend(baseline);
105        self.sample_samples.extend(sample);
106        self.batch_count += 1;
107    }
108
109    /// Add a batch of samples and track online statistics for drift detection.
110    ///
111    /// Both baseline and sample vectors should have the same length.
112    ///
113    /// # Arguments
114    ///
115    /// * `baseline` - Baseline class timing samples (in native units)
116    /// * `sample` - Sample class timing samples (in native units)
117    /// * `ns_per_tick` - Conversion factor from native units to nanoseconds
118    pub fn add_batch_with_conversion(
119        &mut self,
120        baseline: Vec<u64>,
121        sample: Vec<u64>,
122        ns_per_tick: f64,
123    ) {
124        debug_assert_eq!(
125            baseline.len(),
126            sample.len(),
127            "Baseline and sample batch sizes must match"
128        );
129
130        // Store the conversion factor for later use
131        self.ns_per_tick = Some(ns_per_tick);
132
133        // Update online statistics with nanosecond-converted values
134        for &t in &baseline {
135            self.baseline_stats.update(t as f64 * ns_per_tick);
136        }
137        for &t in &sample {
138            self.sample_stats.update(t as f64 * ns_per_tick);
139        }
140
141        // Store raw samples
142        self.baseline_samples.extend(baseline);
143        self.sample_samples.extend(sample);
144        self.batch_count += 1;
145    }
146
147    /// Update KL divergence history with a new value.
148    ///
149    /// Maintains a sliding window of the last 5 KL divergences for
150    /// learning rate monitoring.
151    pub fn update_kl(&mut self, kl: f64) {
152        self.recent_kl_divergences.push_back(kl);
153        if self.recent_kl_divergences.len() > 5 {
154            self.recent_kl_divergences.pop_front();
155        }
156    }
157
158    /// Get the sum of recent KL divergences.
159    ///
160    /// Used to detect learning stall (sum < 0.001 indicates posterior
161    /// has stopped updating despite new data).
162    pub fn recent_kl_sum(&self) -> f64 {
163        self.recent_kl_divergences.iter().sum()
164    }
165
166    /// Check if we have enough KL history for learning rate assessment.
167    pub fn has_kl_history(&self) -> bool {
168        self.recent_kl_divergences.len() >= 5
169    }
170
171    /// Update the posterior and track KL divergence.
172    ///
173    /// Returns the KL divergence from the previous posterior, or 0.0 if
174    /// this is the first posterior.
175    pub fn update_posterior(&mut self, new_posterior: Posterior) -> f64 {
176        let kl = if let Some(ref prev) = self.previous_posterior {
177            kl_divergence_gaussian(&new_posterior, prev)
178        } else {
179            0.0
180        };
181
182        self.previous_posterior = Some(new_posterior);
183
184        if kl.is_finite() {
185            self.update_kl(kl);
186        }
187
188        kl
189    }
190
191    /// Get the current posterior, if computed.
192    pub fn current_posterior(&self) -> Option<&Posterior> {
193        self.previous_posterior.as_ref()
194    }
195
196    /// Convert baseline samples to f64 nanoseconds.
197    pub fn baseline_ns(&self, ns_per_tick: f64) -> Vec<f64> {
198        self.baseline_samples
199            .iter()
200            .map(|&t| t as f64 * ns_per_tick)
201            .collect()
202    }
203
204    /// Convert sample samples to f64 nanoseconds.
205    pub fn sample_ns(&self, ns_per_tick: f64) -> Vec<f64> {
206        self.sample_samples
207            .iter()
208            .map(|&t| t as f64 * ns_per_tick)
209            .collect()
210    }
211
212    /// Get the current online statistics for the baseline class.
213    ///
214    /// Returns `None` if no samples have been added with conversion tracking.
215    pub fn baseline_stats(&self) -> Option<StatsSnapshot> {
216        if self.baseline_stats.count() < 2 {
217            return None;
218        }
219        Some(self.baseline_stats.finalize())
220    }
221
222    /// Get the current online statistics for the sample class.
223    ///
224    /// Returns `None` if no samples have been added with conversion tracking.
225    pub fn sample_stats(&self) -> Option<StatsSnapshot> {
226        if self.sample_stats.count() < 2 {
227            return None;
228        }
229        Some(self.sample_stats.finalize())
230    }
231
232    /// Get a CalibrationSnapshot from the current online statistics.
233    ///
234    /// Returns `None` if insufficient samples have been tracked.
235    pub fn get_stats_snapshot(&self) -> Option<CalibrationSnapshot> {
236        let baseline = self.baseline_stats()?;
237        let sample = self.sample_stats()?;
238        Some(CalibrationSnapshot::new(baseline, sample))
239    }
240
241    /// Check if online statistics are being tracked.
242    ///
243    /// Returns `true` if `add_batch_with_conversion` has been used.
244    pub fn has_stats_tracking(&self) -> bool {
245        self.ns_per_tick.is_some() && self.baseline_stats.count() > 0
246    }
247
248    /// Reset the state for a new test run.
249    ///
250    /// Clears all samples, posteriors, and statistics while preserving capacity.
251    pub fn reset(&mut self) {
252        self.baseline_samples.clear();
253        self.sample_samples.clear();
254        self.previous_posterior = None;
255        self.recent_kl_divergences.clear();
256        self.batch_count = 0;
257        self.baseline_stats = OnlineStats::new();
258        self.sample_stats = OnlineStats::new();
259        self.ns_per_tick = None;
260    }
261}
262
263impl Default for AdaptiveState {
264    fn default() -> Self {
265        Self::new()
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Matrix9, Vector9};

    /// Build a posterior with zero mean and identity covariance; callers vary
    /// only the leak probability and the sample count.
    fn make_test_posterior(leak_prob: f64, n: usize) -> Posterior {
        let mean = Vector9::zeros();
        let covariance = Matrix9::identity();
        let delta_draws = Vec::new();
        let theta = 100.0;
        Posterior::new(mean, covariance, delta_draws, leak_prob, theta, n)
    }

    /// Absolute-tolerance float comparison used by the KL-sum checks.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    #[test]
    fn test_adaptive_state_new() {
        let fresh = AdaptiveState::new();

        assert_eq!(fresh.n_total(), 0);
        assert_eq!(fresh.batch_count, 0);
        assert!(fresh.previous_posterior.is_none());
        assert!(!fresh.has_kl_history());
    }

    #[test]
    fn test_add_batch() {
        let mut st = AdaptiveState::new();
        let base: Vec<u64> = vec![100, 101, 102];
        let samp: Vec<u64> = vec![200, 201, 202];

        st.add_batch(base.clone(), samp.clone());

        assert_eq!(st.n_total(), base.len());
        assert_eq!(st.batch_count, 1);
        assert_eq!(st.baseline_samples, base);
        assert_eq!(st.sample_samples, samp);
    }

    #[test]
    fn test_kl_history() {
        let mut st = AdaptiveState::new();

        // Feed 0.1, 0.2, 0.3, 0.4, 0.5 — exactly enough to fill the window.
        for k in 1..=5 {
            st.update_kl(0.1 * f64::from(k));
        }

        assert!(st.has_kl_history());
        assert!(close(st.recent_kl_sum(), 1.5));

        // A sixth value pushes 0.1 out: 0.2 + 0.3 + 0.4 + 0.5 + 1.0.
        st.update_kl(1.0);
        assert!(close(st.recent_kl_sum(), 2.4));
    }

    #[test]
    fn test_posterior_update() {
        let mut st = AdaptiveState::new();

        // Nothing to compare against yet, so the divergence is zero.
        let first_kl = st.update_posterior(make_test_posterior(0.75, 1000));
        assert_eq!(first_kl, 0.0);
        assert!(st.current_posterior().is_some());

        // Same mean/covariance as before, so the divergence may legitimately be 0.
        let second_kl = st.update_posterior(make_test_posterior(0.80, 2000));
        assert!(second_kl >= 0.0);
    }

    #[test]
    fn test_add_batch_with_conversion() {
        let mut st = AdaptiveState::new();

        // 2 ns per tick.
        st.add_batch_with_conversion(vec![100, 110, 120], vec![200, 210, 220], 2.0);

        assert_eq!(st.n_total(), 3);
        assert_eq!(st.batch_count, 1);
        assert!(st.has_stats_tracking());

        // Raw samples are stored unconverted.
        assert_eq!(st.baseline_samples, vec![100, 110, 120]);
        assert_eq!(st.sample_samples, vec![200, 210, 220]);
    }

    #[test]
    fn test_online_stats_tracking() {
        let mut st = AdaptiveState::new();

        // 100 values cycling base+0 ..= base+9, so each class mean is base + 4.5.
        let sawtooth = |base: u64| -> Vec<u64> { (0..100).map(|i| base + i % 10).collect() };
        st.add_batch_with_conversion(sawtooth(1000), sawtooth(1100), 1.0);

        let b = st.baseline_stats().expect("Should have baseline stats");
        assert_eq!(b.count, 100);
        assert!(
            (b.mean - 1004.5).abs() < 1.0,
            "Baseline mean {} should be near 1004.5",
            b.mean
        );

        let s = st.sample_stats().expect("Should have sample stats");
        assert_eq!(s.count, 100);
        assert!(
            (s.mean - 1104.5).abs() < 1.0,
            "Sample mean {} should be near 1104.5",
            s.mean
        );
    }

    #[test]
    fn test_reset() {
        let mut st = AdaptiveState::new();

        st.add_batch_with_conversion(vec![100, 110], vec![200, 210], 1.0);
        st.update_kl(0.5);
        st.update_posterior(make_test_posterior(0.75, 100));
        assert!(st.n_total() > 0);

        st.reset();

        assert_eq!(st.n_total(), 0);
        assert_eq!(st.batch_count, 0);
        assert!(st.previous_posterior.is_none());
        assert!(!st.has_kl_history());
        assert!(!st.has_stats_tracking());
    }

    #[test]
    fn test_stats_not_tracked_without_conversion() {
        let mut st = AdaptiveState::new();

        // The plain path never touches the online statistics.
        st.add_batch(vec![100, 110, 120], vec![200, 210, 220]);

        assert!(!st.has_stats_tracking());
        assert!(st.baseline_stats().is_none());
        assert!(st.sample_stats().is_none());
    }
}