1use crate::types::{QueryTrajectory, TrajectoryStep};
6use crossbeam::queue::ArrayQueue;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::Instant;
9
10pub struct TrajectoryBuffer {
12 buffer: ArrayQueue<QueryTrajectory>,
14 capacity: usize,
16 dropped: AtomicU64,
18 total_seen: AtomicU64,
20}
21
22impl TrajectoryBuffer {
23 pub fn new(capacity: usize) -> Self {
25 Self {
26 buffer: ArrayQueue::new(capacity),
27 capacity,
28 dropped: AtomicU64::new(0),
29 total_seen: AtomicU64::new(0),
30 }
31 }
32
33 pub fn record(&self, trajectory: QueryTrajectory) -> bool {
37 self.total_seen.fetch_add(1, Ordering::Relaxed);
38
39 match self.buffer.push(trajectory) {
40 Ok(()) => true,
41 Err(_) => {
42 self.dropped.fetch_add(1, Ordering::Relaxed);
43 false
44 }
45 }
46 }
47
48 pub fn pop(&self) -> Option<QueryTrajectory> {
50 self.buffer.pop()
51 }
52
53 pub fn drain(&self) -> Vec<QueryTrajectory> {
55 let mut result = Vec::with_capacity(self.len());
56 while let Some(t) = self.buffer.pop() {
57 result.push(t);
58 }
59 result
60 }
61
62 pub fn drain_n(&self, n: usize) -> Vec<QueryTrajectory> {
64 let mut result = Vec::with_capacity(n.min(self.len()));
65 for _ in 0..n {
66 match self.buffer.pop() {
67 Some(t) => result.push(t),
68 None => break,
69 }
70 }
71 result
72 }
73
74 pub fn len(&self) -> usize {
76 self.buffer.len()
77 }
78
79 pub fn is_empty(&self) -> bool {
81 self.buffer.is_empty()
82 }
83
84 pub fn is_full(&self) -> bool {
86 self.buffer.is_full()
87 }
88
89 pub fn capacity(&self) -> usize {
91 self.capacity
92 }
93
94 pub fn dropped_count(&self) -> u64 {
96 self.dropped.load(Ordering::Relaxed)
97 }
98
99 pub fn total_seen(&self) -> u64 {
101 self.total_seen.load(Ordering::Relaxed)
102 }
103
104 pub fn success_rate(&self) -> f64 {
106 let total = self.total_seen.load(Ordering::Relaxed);
107 let dropped = self.dropped.load(Ordering::Relaxed);
108 if total == 0 {
109 1.0
110 } else {
111 (total - dropped) as f64 / total as f64
112 }
113 }
114
115 pub fn reset_stats(&self) {
117 self.dropped.store(0, Ordering::Relaxed);
118 self.total_seen.store(0, Ordering::Relaxed);
119 }
120
121 pub fn get_all(&self) -> Vec<QueryTrajectory> {
127 self.drain()
128 }
129}
130
131pub struct TrajectoryBuilder {
133 id: u64,
135 query_embedding: Vec<f32>,
137 steps: Vec<TrajectoryStep>,
139 start_time: Instant,
141 model_route: Option<String>,
143 context_ids: Vec<String>,
145}
146
147impl TrajectoryBuilder {
148 pub fn new(id: u64, query_embedding: Vec<f32>) -> Self {
150 Self {
151 id,
152 query_embedding,
153 steps: Vec::with_capacity(16),
154 start_time: Instant::now(),
155 model_route: None,
156 context_ids: Vec::new(),
157 }
158 }
159
160 pub fn add_step(&mut self, activations: Vec<f32>, attention_weights: Vec<f32>, reward: f32) {
162 let step_idx = self.steps.len();
163 self.steps.push(TrajectoryStep::new(
164 activations,
165 attention_weights,
166 reward,
167 step_idx,
168 ));
169 }
170
171 pub fn add_named_step(&mut self, name: &str, activations: Vec<f32>, attention_weights: Vec<f32>, reward: f32) {
173 let step_idx = self.steps.len();
174 self.steps.push(
175 TrajectoryStep::new(activations, attention_weights, reward, step_idx)
176 .with_layer(name)
177 );
178 }
179
180 pub fn set_model_route(&mut self, route: &str) {
182 self.model_route = Some(route.to_string());
183 }
184
185 pub fn add_context(&mut self, context_id: &str) {
187 self.context_ids.push(context_id.to_string());
188 }
189
190 pub fn step_count(&self) -> usize {
192 self.steps.len()
193 }
194
195 pub fn elapsed(&self) -> std::time::Duration {
197 self.start_time.elapsed()
198 }
199
200 pub fn build(self, final_quality: f32) -> QueryTrajectory {
202 let latency_us = self.start_time.elapsed().as_micros() as u64;
203
204 QueryTrajectory {
205 id: self.id,
206 query_embedding: self.query_embedding,
207 steps: self.steps,
208 final_quality,
209 latency_us,
210 model_route: self.model_route,
211 context_ids: self.context_ids,
212 }
213 }
214
215 pub fn build_with_latency(self, final_quality: f32, latency_us: u64) -> QueryTrajectory {
217 QueryTrajectory {
218 id: self.id,
219 query_embedding: self.query_embedding,
220 steps: self.steps,
221 final_quality,
222 latency_us,
223 model_route: self.model_route,
224 context_ids: self.context_ids,
225 }
226 }
227}
228
/// Generator of sequential `u64` trajectory ids backed by an atomic counter.
///
/// Ids can be drawn from multiple threads through `&self`. Because the counter
/// uses `fetch_add`, each call receives a distinct value until the counter wraps.
#[derive(Debug, Default)]
pub struct TrajectoryIdGen {
    /// The id the next call to `next` will return.
    counter: AtomicU64,
}

impl TrajectoryIdGen {
    /// Creates a generator whose first id is `0`.
    ///
    /// This is a `const fn`, so the generator can initialise a `static`.
    pub const fn new() -> Self {
        Self::with_start(0)
    }

    /// Creates a generator whose first id is `start`.
    pub const fn with_start(start: u64) -> Self {
        Self {
            counter: AtomicU64::new(start),
        }
    }

    /// Returns the next id and advances the counter by one.
    ///
    /// `fetch_add` wraps on overflow, so after `u64::MAX` the sequence restarts
    /// at `0` and ids are no longer unique.
    pub fn next(&self) -> u64 {
        self.counter.fetch_add(1, Ordering::Relaxed)
    }

    /// Returns the id the next call to [`next`](Self::next) would hand out,
    /// without advancing the counter.
    ///
    /// Under concurrent use the value may already be stale when observed.
    pub fn current(&self) -> u64 {
        self.counter.load(Ordering::Relaxed)
    }
}
265
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    #[test]
    fn test_buffer_basic_ops() {
        let buffer = TrajectoryBuffer::new(10);

        assert!(buffer.is_empty());
        assert_eq!(buffer.capacity(), 10);

        let trajectory = QueryTrajectory::new(1, vec![0.1, 0.2]);
        assert!(buffer.record(trajectory));

        assert_eq!(buffer.len(), 1);
        assert!(!buffer.is_empty());
    }

    #[test]
    fn test_buffer_overflow() {
        let buffer = TrajectoryBuffer::new(3);

        for i in 0..5 {
            let trajectory = QueryTrajectory::new(i, vec![0.1]);
            buffer.record(trajectory);
        }

        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.dropped_count(), 2);
        assert_eq!(buffer.total_seen(), 5);
    }

    #[test]
    fn test_buffer_drain() {
        let buffer = TrajectoryBuffer::new(10);

        for i in 0..5 {
            let trajectory = QueryTrajectory::new(i, vec![0.1]);
            buffer.record(trajectory);
        }

        let drained = buffer.drain();
        assert_eq!(drained.len(), 5);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_buffer_drain_n() {
        let buffer = TrajectoryBuffer::new(10);

        for i in 0..5 {
            let trajectory = QueryTrajectory::new(i, vec![0.1]);
            buffer.record(trajectory);
        }

        let partial = buffer.drain_n(3);
        assert_eq!(partial.len(), 3);
        assert_eq!(buffer.len(), 2);

        // Asking for more than remains returns only what is left.
        let rest = buffer.drain_n(10);
        assert_eq!(rest.len(), 2);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_get_all_drains_buffer() {
        let buffer = TrajectoryBuffer::new(4);

        for i in 0..3 {
            buffer.record(QueryTrajectory::new(i, vec![0.1]));
        }

        assert_eq!(buffer.get_all().len(), 3);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_builder() {
        // Started before the builder and read after `build`, so it bounds the
        // builder's own latency from above.
        let outer = Instant::now();
        let mut builder = TrajectoryBuilder::new(42, vec![0.1, 0.2, 0.3]);

        builder.add_step(vec![0.5], vec![0.4, 0.6], 0.7);
        builder.add_step(vec![0.6], vec![0.3, 0.7], 0.8);
        builder.set_model_route("llama-7b");
        builder.add_context("ctx-123");

        assert_eq!(builder.step_count(), 2);

        let trajectory = builder.build(0.85);
        let outer_us = outer.elapsed().as_micros();

        assert_eq!(trajectory.id, 42);
        assert_eq!(trajectory.steps.len(), 2);
        assert_eq!(trajectory.final_quality, 0.85);
        assert_eq!(trajectory.model_route, Some("llama-7b".to_string()));
        assert_eq!(trajectory.context_ids, vec!["ctx-123".to_string()]);
        // `latency_us > 0` would be flaky: the steps above can finish in under 1µs.
        assert!(u128::from(trajectory.latency_us) <= outer_us);
    }

    #[test]
    fn test_build_with_latency_uses_given_value() {
        let mut builder = TrajectoryBuilder::new(7, vec![0.0]);
        builder.add_step(vec![1.0], vec![1.0], 0.5);

        let trajectory = builder.build_with_latency(0.9, 1234);

        assert_eq!(trajectory.id, 7);
        assert_eq!(trajectory.latency_us, 1234);
        assert_eq!(trajectory.steps.len(), 1);
        assert_eq!(trajectory.model_route, None);
        assert!(trajectory.context_ids.is_empty());
    }

    #[test]
    fn test_id_generator() {
        // Not named `gen`: that identifier is a reserved keyword in Rust 2024.
        let ids = TrajectoryIdGen::new();

        assert_eq!(ids.next(), 0);
        assert_eq!(ids.next(), 1);
        assert_eq!(ids.next(), 2);
        assert_eq!(ids.current(), 3);
    }

    #[test]
    fn test_success_rate() {
        let buffer = TrajectoryBuffer::new(2);

        for i in 0..4 {
            buffer.record(QueryTrajectory::new(i, vec![]));
        }

        assert!((buffer.success_rate() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_success_rate_empty_and_reset() {
        let buffer = TrajectoryBuffer::new(1);
        assert_eq!(buffer.success_rate(), 1.0);

        buffer.record(QueryTrajectory::new(0, vec![]));
        buffer.record(QueryTrajectory::new(1, vec![]));
        assert_eq!(buffer.dropped_count(), 1);

        buffer.reset_stats();
        assert_eq!(buffer.total_seen(), 0);
        assert_eq!(buffer.dropped_count(), 0);
        assert_eq!(buffer.success_rate(), 1.0);
    }
}