1use alloc::collections::VecDeque;
10use alloc::vec::Vec;
11
12use crate::statistics::{OnlineStats, StatsSnapshot};
13
14use super::{kl_divergence_gaussian, CalibrationSnapshot, Posterior};
15
16pub struct AdaptiveState {
31 pub baseline_samples: Vec<u64>,
33
34 pub sample_samples: Vec<u64>,
36
37 pub previous_posterior: Option<Posterior>,
40
41 pub recent_kl_divergences: VecDeque<f64>,
44
45 pub batch_count: usize,
47
48 baseline_stats: OnlineStats,
50
51 sample_stats: OnlineStats,
53
54 ns_per_tick: Option<f64>,
57}
58
59impl AdaptiveState {
60 pub fn new() -> Self {
62 Self {
63 baseline_samples: Vec::new(),
64 sample_samples: Vec::new(),
65 previous_posterior: None,
66 recent_kl_divergences: VecDeque::with_capacity(5),
67 batch_count: 0,
68 baseline_stats: OnlineStats::new(),
69 sample_stats: OnlineStats::new(),
70 ns_per_tick: None,
71 }
72 }
73
74 pub fn with_capacity(expected_samples: usize) -> Self {
76 Self {
77 baseline_samples: Vec::with_capacity(expected_samples),
78 sample_samples: Vec::with_capacity(expected_samples),
79 previous_posterior: None,
80 recent_kl_divergences: VecDeque::with_capacity(5),
81 batch_count: 0,
82 baseline_stats: OnlineStats::new(),
83 sample_stats: OnlineStats::new(),
84 ns_per_tick: None,
85 }
86 }
87
88 pub fn n_total(&self) -> usize {
90 self.baseline_samples.len()
91 }
92
93 pub fn add_batch(&mut self, baseline: Vec<u64>, sample: Vec<u64>) {
99 debug_assert_eq!(
100 baseline.len(),
101 sample.len(),
102 "Baseline and sample batch sizes must match"
103 );
104 self.baseline_samples.extend(baseline);
105 self.sample_samples.extend(sample);
106 self.batch_count += 1;
107 }
108
109 pub fn add_batch_with_conversion(
119 &mut self,
120 baseline: Vec<u64>,
121 sample: Vec<u64>,
122 ns_per_tick: f64,
123 ) {
124 debug_assert_eq!(
125 baseline.len(),
126 sample.len(),
127 "Baseline and sample batch sizes must match"
128 );
129
130 self.ns_per_tick = Some(ns_per_tick);
132
133 for &t in &baseline {
135 self.baseline_stats.update(t as f64 * ns_per_tick);
136 }
137 for &t in &sample {
138 self.sample_stats.update(t as f64 * ns_per_tick);
139 }
140
141 self.baseline_samples.extend(baseline);
143 self.sample_samples.extend(sample);
144 self.batch_count += 1;
145 }
146
147 pub fn update_kl(&mut self, kl: f64) {
152 self.recent_kl_divergences.push_back(kl);
153 if self.recent_kl_divergences.len() > 5 {
154 self.recent_kl_divergences.pop_front();
155 }
156 }
157
158 pub fn recent_kl_sum(&self) -> f64 {
163 self.recent_kl_divergences.iter().sum()
164 }
165
166 pub fn has_kl_history(&self) -> bool {
168 self.recent_kl_divergences.len() >= 5
169 }
170
171 pub fn update_posterior(&mut self, new_posterior: Posterior) -> f64 {
176 let kl = if let Some(ref prev) = self.previous_posterior {
177 kl_divergence_gaussian(&new_posterior, prev)
178 } else {
179 0.0
180 };
181
182 self.previous_posterior = Some(new_posterior);
183
184 if kl.is_finite() {
185 self.update_kl(kl);
186 }
187
188 kl
189 }
190
191 pub fn current_posterior(&self) -> Option<&Posterior> {
193 self.previous_posterior.as_ref()
194 }
195
196 pub fn baseline_ns(&self, ns_per_tick: f64) -> Vec<f64> {
198 self.baseline_samples
199 .iter()
200 .map(|&t| t as f64 * ns_per_tick)
201 .collect()
202 }
203
204 pub fn sample_ns(&self, ns_per_tick: f64) -> Vec<f64> {
206 self.sample_samples
207 .iter()
208 .map(|&t| t as f64 * ns_per_tick)
209 .collect()
210 }
211
212 pub fn baseline_stats(&self) -> Option<StatsSnapshot> {
216 if self.baseline_stats.count() < 2 {
217 return None;
218 }
219 Some(self.baseline_stats.finalize())
220 }
221
222 pub fn sample_stats(&self) -> Option<StatsSnapshot> {
226 if self.sample_stats.count() < 2 {
227 return None;
228 }
229 Some(self.sample_stats.finalize())
230 }
231
232 pub fn get_stats_snapshot(&self) -> Option<CalibrationSnapshot> {
236 let baseline = self.baseline_stats()?;
237 let sample = self.sample_stats()?;
238 Some(CalibrationSnapshot::new(baseline, sample))
239 }
240
241 pub fn has_stats_tracking(&self) -> bool {
245 self.ns_per_tick.is_some() && self.baseline_stats.count() > 0
246 }
247
248 pub fn reset(&mut self) {
252 self.baseline_samples.clear();
253 self.sample_samples.clear();
254 self.previous_posterior = None;
255 self.recent_kl_divergences.clear();
256 self.batch_count = 0;
257 self.baseline_stats = OnlineStats::new();
258 self.sample_stats = OnlineStats::new();
259 self.ns_per_tick = None;
260 }
261}
262
263impl Default for AdaptiveState {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Matrix2, Matrix9, Vector2, Vector9};

    /// Builds a posterior with zeroed/identity full-rank parts and the given
    /// projected mean/covariance, leak probability and sample count.
    fn make_test_posterior(
        beta_proj: Vector2,
        beta_proj_cov: Matrix2,
        leak_prob: f64,
        n: usize,
    ) -> Posterior {
        Posterior::new(
            Vector9::zeros(),
            Matrix9::identity(),
            beta_proj,
            beta_proj_cov,
            Vec::new(),
            leak_prob,
            1.0,
            n,
        )
    }

    #[test]
    fn test_adaptive_state_new() {
        let state = AdaptiveState::new();
        assert_eq!(state.n_total(), 0);
        assert_eq!(state.batch_count, 0);
        assert!(state.previous_posterior.is_none());
        assert!(!state.has_kl_history());
    }

    #[test]
    fn test_with_capacity() {
        let state = AdaptiveState::with_capacity(16);
        assert_eq!(state.n_total(), 0);
        assert_eq!(state.batch_count, 0);
        assert!(state.baseline_samples.capacity() >= 16);
        assert!(state.sample_samples.capacity() >= 16);
        assert!(!state.has_stats_tracking());
    }

    #[test]
    fn test_add_batch() {
        let mut state = AdaptiveState::new();
        state.add_batch(vec![100, 101, 102], vec![200, 201, 202]);

        assert_eq!(state.n_total(), 3);
        assert_eq!(state.batch_count, 1);
        assert_eq!(state.baseline_samples, vec![100, 101, 102]);
        assert_eq!(state.sample_samples, vec![200, 201, 202]);
    }

    #[test]
    fn test_multiple_batches_accumulate() {
        let mut state = AdaptiveState::new();
        state.add_batch(vec![1, 2], vec![3, 4]);
        state.add_batch(vec![5], vec![6]);

        assert_eq!(state.n_total(), 3);
        assert_eq!(state.batch_count, 2);
        assert_eq!(state.baseline_samples, vec![1, 2, 5]);
        assert_eq!(state.sample_samples, vec![3, 4, 6]);
    }

    #[test]
    fn test_kl_history() {
        let mut state = AdaptiveState::new();

        for i in 0..5 {
            state.update_kl(0.1 * (i + 1) as f64);
        }

        assert!(state.has_kl_history());
        assert!((state.recent_kl_sum() - 1.5).abs() < 1e-10);

        // Pushing a sixth value evicts the oldest (0.1).
        state.update_kl(1.0);
        assert!((state.recent_kl_sum() - 2.4).abs() < 1e-10);
    }

    #[test]
    fn test_kl_window_is_bounded() {
        let mut state = AdaptiveState::new();
        for i in 0..10 {
            state.update_kl(i as f64);
        }

        // Only the last five values (5..=9) remain.
        assert_eq!(state.recent_kl_divergences.len(), 5);
        assert!((state.recent_kl_sum() - 35.0).abs() < 1e-10);
    }

    #[test]
    fn test_posterior_update() {
        let mut state = AdaptiveState::new();

        let posterior1 = make_test_posterior(
            Vector2::new(10.0, 5.0),
            Matrix2::new(4.0, 0.0, 0.0, 1.0),
            0.75,
            1000,
        );

        // The first posterior has nothing to compare against.
        let kl1 = state.update_posterior(posterior1.clone());
        assert_eq!(kl1, 0.0);
        assert!(state.current_posterior().is_some());

        let posterior2 = make_test_posterior(
            Vector2::new(11.0, 5.5),
            Matrix2::new(3.5, 0.0, 0.0, 0.9),
            0.80,
            2000,
        );
        let kl2 = state.update_posterior(posterior2);
        assert!(kl2 > 0.0);
    }

    #[test]
    fn test_add_batch_with_conversion() {
        let mut state = AdaptiveState::new();

        state.add_batch_with_conversion(vec![100, 110, 120], vec![200, 210, 220], 2.0);

        assert_eq!(state.n_total(), 3);
        assert_eq!(state.batch_count, 1);
        assert!(state.has_stats_tracking());

        // Raw buffers keep the original tick values, unscaled.
        assert_eq!(state.baseline_samples, vec![100, 110, 120]);
        assert_eq!(state.sample_samples, vec![200, 210, 220]);
    }

    #[test]
    fn test_tick_to_ns_conversion() {
        let mut state = AdaptiveState::new();
        state.add_batch(vec![100, 110], vec![200, 210]);

        assert_eq!(state.baseline_ns(2.0), vec![200.0, 220.0]);
        assert_eq!(state.sample_ns(0.5), vec![100.0, 105.0]);
    }

    #[test]
    fn test_online_stats_tracking() {
        let mut state = AdaptiveState::new();

        let baseline: Vec<u64> = (0..100).map(|i| 1000 + (i % 10)).collect();
        let sample: Vec<u64> = (0..100).map(|i| 1100 + (i % 10)).collect();
        state.add_batch_with_conversion(baseline, sample, 1.0);

        let baseline_stats = state.baseline_stats().expect("Should have baseline stats");
        assert_eq!(baseline_stats.count, 100);
        assert!(
            (baseline_stats.mean - 1004.5).abs() < 1.0,
            "Baseline mean {} should be near 1004.5",
            baseline_stats.mean
        );

        let sample_stats = state.sample_stats().expect("Should have sample stats");
        assert_eq!(sample_stats.count, 100);
        assert!(
            (sample_stats.mean - 1104.5).abs() < 1.0,
            "Sample mean {} should be near 1104.5",
            sample_stats.mean
        );

        assert!(state.get_stats_snapshot().is_some());
    }

    #[test]
    fn test_stats_require_two_samples() {
        let mut state = AdaptiveState::new();
        state.add_batch_with_conversion(vec![100], vec![200], 1.0);

        // Tracking has started, but one observation is not enough for a snapshot.
        assert!(state.has_stats_tracking());
        assert!(state.baseline_stats().is_none());
        assert!(state.sample_stats().is_none());
        assert!(state.get_stats_snapshot().is_none());
    }

    #[test]
    fn test_reset() {
        let mut state = AdaptiveState::new();

        state.add_batch_with_conversion(vec![100, 110], vec![200, 210], 1.0);
        state.update_kl(0.5);
        let posterior = make_test_posterior(
            Vector2::new(10.0, 5.0),
            Matrix2::new(4.0, 0.0, 0.0, 1.0),
            0.75,
            100,
        );
        state.update_posterior(posterior);

        assert!(state.n_total() > 0);

        state.reset();

        assert_eq!(state.n_total(), 0);
        assert_eq!(state.batch_count, 0);
        assert!(state.previous_posterior.is_none());
        assert!(state.current_posterior().is_none());
        assert!(!state.has_kl_history());
        assert!(!state.has_stats_tracking());
    }

    #[test]
    fn test_stats_not_tracked_without_conversion() {
        let mut state = AdaptiveState::new();

        state.add_batch(vec![100, 110, 120], vec![200, 210, 220]);

        assert!(!state.has_stats_tracking());
        assert!(state.baseline_stats().is_none());
        assert!(state.sample_stats().is_none());
        assert!(state.get_stats_snapshot().is_none());
    }
}