1use alloc::collections::VecDeque;
10use alloc::vec::Vec;
11
12use crate::statistics::{OnlineStats, StatsSnapshot};
13
14use super::{kl_divergence_gaussian, CalibrationSnapshot, Posterior};
15
16pub struct AdaptiveState {
31 pub baseline_samples: Vec<u64>,
33
34 pub sample_samples: Vec<u64>,
36
37 pub previous_posterior: Option<Posterior>,
40
41 pub recent_kl_divergences: VecDeque<f64>,
44
45 pub batch_count: usize,
47
48 baseline_stats: OnlineStats,
50
51 sample_stats: OnlineStats,
53
54 ns_per_tick: Option<f64>,
57}
58
59impl AdaptiveState {
60 pub fn new() -> Self {
62 Self {
63 baseline_samples: Vec::new(),
64 sample_samples: Vec::new(),
65 previous_posterior: None,
66 recent_kl_divergences: VecDeque::with_capacity(5),
67 batch_count: 0,
68 baseline_stats: OnlineStats::new(),
69 sample_stats: OnlineStats::new(),
70 ns_per_tick: None,
71 }
72 }
73
74 pub fn with_capacity(expected_samples: usize) -> Self {
76 Self {
77 baseline_samples: Vec::with_capacity(expected_samples),
78 sample_samples: Vec::with_capacity(expected_samples),
79 previous_posterior: None,
80 recent_kl_divergences: VecDeque::with_capacity(5),
81 batch_count: 0,
82 baseline_stats: OnlineStats::new(),
83 sample_stats: OnlineStats::new(),
84 ns_per_tick: None,
85 }
86 }
87
88 pub fn n_total(&self) -> usize {
90 self.baseline_samples.len()
91 }
92
93 pub fn add_batch(&mut self, baseline: Vec<u64>, sample: Vec<u64>) {
99 debug_assert_eq!(
100 baseline.len(),
101 sample.len(),
102 "Baseline and sample batch sizes must match"
103 );
104 self.baseline_samples.extend(baseline);
105 self.sample_samples.extend(sample);
106 self.batch_count += 1;
107 }
108
109 pub fn add_batch_with_conversion(
119 &mut self,
120 baseline: Vec<u64>,
121 sample: Vec<u64>,
122 ns_per_tick: f64,
123 ) {
124 debug_assert_eq!(
125 baseline.len(),
126 sample.len(),
127 "Baseline and sample batch sizes must match"
128 );
129
130 self.ns_per_tick = Some(ns_per_tick);
132
133 for &t in &baseline {
135 self.baseline_stats.update(t as f64 * ns_per_tick);
136 }
137 for &t in &sample {
138 self.sample_stats.update(t as f64 * ns_per_tick);
139 }
140
141 self.baseline_samples.extend(baseline);
143 self.sample_samples.extend(sample);
144 self.batch_count += 1;
145 }
146
147 pub fn update_kl(&mut self, kl: f64) {
152 self.recent_kl_divergences.push_back(kl);
153 if self.recent_kl_divergences.len() > 5 {
154 self.recent_kl_divergences.pop_front();
155 }
156 }
157
158 pub fn recent_kl_sum(&self) -> f64 {
163 self.recent_kl_divergences.iter().sum()
164 }
165
166 pub fn has_kl_history(&self) -> bool {
168 self.recent_kl_divergences.len() >= 5
169 }
170
171 pub fn update_posterior(&mut self, new_posterior: Posterior) -> f64 {
176 let kl = if let Some(ref prev) = self.previous_posterior {
177 kl_divergence_gaussian(&new_posterior, prev)
178 } else {
179 0.0
180 };
181
182 self.previous_posterior = Some(new_posterior);
183
184 if kl.is_finite() {
185 self.update_kl(kl);
186 }
187
188 kl
189 }
190
191 pub fn current_posterior(&self) -> Option<&Posterior> {
193 self.previous_posterior.as_ref()
194 }
195
196 pub fn baseline_ns(&self, ns_per_tick: f64) -> Vec<f64> {
198 self.baseline_samples
199 .iter()
200 .map(|&t| t as f64 * ns_per_tick)
201 .collect()
202 }
203
204 pub fn sample_ns(&self, ns_per_tick: f64) -> Vec<f64> {
206 self.sample_samples
207 .iter()
208 .map(|&t| t as f64 * ns_per_tick)
209 .collect()
210 }
211
212 pub fn baseline_stats(&self) -> Option<StatsSnapshot> {
216 if self.baseline_stats.count() < 2 {
217 return None;
218 }
219 Some(self.baseline_stats.finalize())
220 }
221
222 pub fn sample_stats(&self) -> Option<StatsSnapshot> {
226 if self.sample_stats.count() < 2 {
227 return None;
228 }
229 Some(self.sample_stats.finalize())
230 }
231
232 pub fn get_stats_snapshot(&self) -> Option<CalibrationSnapshot> {
236 let baseline = self.baseline_stats()?;
237 let sample = self.sample_stats()?;
238 Some(CalibrationSnapshot::new(baseline, sample))
239 }
240
241 pub fn has_stats_tracking(&self) -> bool {
245 self.ns_per_tick.is_some() && self.baseline_stats.count() > 0
246 }
247
248 pub fn reset(&mut self) {
252 self.baseline_samples.clear();
253 self.sample_samples.clear();
254 self.previous_posterior = None;
255 self.recent_kl_divergences.clear();
256 self.batch_count = 0;
257 self.baseline_stats = OnlineStats::new();
258 self.sample_stats = OnlineStats::new();
259 self.ns_per_tick = None;
260 }
261}
262
263impl Default for AdaptiveState {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Matrix9, Vector9};

    /// Builds a posterior from a zero vector, an identity matrix and an empty
    /// `Vec`, varying only `leak_prob` and `n` between tests.
    fn make_test_posterior(leak_prob: f64, n: usize) -> Posterior {
        Posterior::new(
            Vector9::zeros(),
            Matrix9::identity(),
            Vec::new(),
            leak_prob,
            100.0,
            n,
        )
    }

    #[test]
    fn test_adaptive_state_new() {
        let state = AdaptiveState::new();

        assert_eq!(state.n_total(), 0);
        assert_eq!(state.batch_count, 0);
        assert!(state.previous_posterior.is_none());
        assert!(!state.has_kl_history());
    }

    #[test]
    fn test_add_batch() {
        let mut state = AdaptiveState::new();
        state.add_batch(vec![100, 101, 102], vec![200, 201, 202]);

        assert_eq!(state.n_total(), 3);
        assert_eq!(state.batch_count, 1);
        assert_eq!(state.baseline_samples, vec![100, 101, 102]);
        assert_eq!(state.sample_samples, vec![200, 201, 202]);
    }

    #[test]
    fn test_kl_history() {
        let mut state = AdaptiveState::new();

        for i in 0..5 {
            state.update_kl(0.1 * (i + 1) as f64);
        }

        assert!(state.has_kl_history());
        // 0.1 + 0.2 + 0.3 + 0.4 + 0.5
        assert!((state.recent_kl_sum() - 1.5).abs() < 1e-10);

        // A sixth value evicts the oldest one (0.1):
        // 0.2 + 0.3 + 0.4 + 0.5 + 1.0
        state.update_kl(1.0);
        assert_eq!(state.recent_kl_divergences.len(), 5);
        assert!((state.recent_kl_sum() - 2.4).abs() < 1e-10);
    }

    #[test]
    fn test_posterior_update() {
        let mut state = AdaptiveState::new();

        // The first update has nothing to compare against.
        let posterior1 = make_test_posterior(0.75, 1000);
        let kl1 = state.update_posterior(posterior1);
        assert_eq!(kl1, 0.0);
        assert!(state.current_posterior().is_some());

        // The second update is compared against the first.
        let posterior2 = make_test_posterior(0.80, 2000);
        let kl2 = state.update_posterior(posterior2);
        assert!(kl2 >= 0.0);
    }

    #[test]
    fn test_add_batch_with_conversion() {
        let mut state = AdaptiveState::new();

        state.add_batch_with_conversion(vec![100, 110, 120], vec![200, 210, 220], 2.0);

        assert_eq!(state.n_total(), 3);
        assert_eq!(state.batch_count, 1);
        assert!(state.has_stats_tracking());

        // Raw values are stored unconverted.
        assert_eq!(state.baseline_samples, vec![100, 110, 120]);
        assert_eq!(state.sample_samples, vec![200, 210, 220]);

        // Conversion happens on demand.
        assert_eq!(state.baseline_ns(2.0), vec![200.0, 220.0, 240.0]);
        assert_eq!(state.sample_ns(2.0), vec![400.0, 420.0, 440.0]);
    }

    #[test]
    fn test_online_stats_tracking() {
        let mut state = AdaptiveState::new();

        let baseline: Vec<u64> = (0..100).map(|i| 1000 + (i % 10)).collect();
        let sample: Vec<u64> = (0..100).map(|i| 1100 + (i % 10)).collect();
        state.add_batch_with_conversion(baseline, sample, 1.0);

        let baseline_stats = state.baseline_stats().expect("Should have baseline stats");
        assert_eq!(baseline_stats.count, 100);
        assert!(
            (baseline_stats.mean - 1004.5).abs() < 1.0,
            "Baseline mean {} should be near 1004.5",
            baseline_stats.mean
        );

        let sample_stats = state.sample_stats().expect("Should have sample stats");
        assert_eq!(sample_stats.count, 100);
        assert!(
            (sample_stats.mean - 1104.5).abs() < 1.0,
            "Sample mean {} should be near 1104.5",
            sample_stats.mean
        );
    }

    #[test]
    fn test_reset() {
        let mut state = AdaptiveState::new();

        state.add_batch_with_conversion(vec![100, 110], vec![200, 210], 1.0);
        state.update_kl(0.5);
        let posterior = make_test_posterior(0.75, 100);
        state.update_posterior(posterior);

        assert!(state.n_total() > 0);

        state.reset();

        assert_eq!(state.n_total(), 0);
        assert_eq!(state.batch_count, 0);
        assert!(state.previous_posterior.is_none());
        assert!(state.current_posterior().is_none());
        assert!(!state.has_kl_history());
        assert!(!state.has_stats_tracking());
        assert!(state.get_stats_snapshot().is_none());
    }

    #[test]
    fn test_stats_not_tracked_without_conversion() {
        let mut state = AdaptiveState::new();

        state.add_batch(vec![100, 110, 120], vec![200, 210, 220]);

        // `add_batch` stores raw values only; the running statistics stay empty.
        assert!(!state.has_stats_tracking());
        assert!(state.baseline_stats().is_none());
        assert!(state.sample_stats().is_none());
        assert!(state.get_stats_snapshot().is_none());
    }
}