1use std::collections::HashMap;
7use std::sync::Arc;
8
9use serde_json::Value;
10use tokio::sync::Mutex;
11use uuid::Uuid;
12
13use super::error::TracingError;
14use super::hooks::{HookContext, HookEvent, HookManager};
15use super::models::{
16 EventReward, MarkovBlanketMessage, MessageContent, OutcomeReward, SessionTimeStep,
17 SessionTrace, TimeRecord, TracingEvent,
18};
19use super::storage::TraceStorage;
20
/// Tracks at most one active session and its current timestep.
///
/// Recorded data is kept in memory and, when `auto_save` is enabled, mirrored to a
/// [`TraceStorage`] backend. Registered hooks are notified of lifecycle events.
pub struct SessionTracer {
    /// Backend used for persisting sessions, timesteps, events, messages and rewards,
    /// and for loading/deleting stored sessions.
    storage: Arc<dyn TraceStorage>,
    /// Callbacks triggered on session/timestep start and end and on recorded events/messages.
    hooks: Mutex<HookManager>,
    /// In-memory trace of the session in progress; `None` when no session is active.
    current_session: Mutex<Option<SessionTrace>>,
    /// Timestep in progress; `None` when no timestep is open.
    current_step: Mutex<Option<CurrentStep>>,
    /// Gates the storage writes performed by the lifecycle and `record_*` methods.
    auto_save: bool,
}
53
/// Bookkeeping for the timestep currently in progress.
struct CurrentStep {
    /// In-memory timestep data; appended to the session's step history when it ends.
    step: SessionTimeStep,
    /// Id returned by `TraceStorage::ensure_timestep`, passed along when storing events and
    /// messages; `None` when auto-save was disabled at step start.
    db_id: Option<i64>,
}
59
60impl SessionTracer {
61 pub fn new(storage: Arc<dyn TraceStorage>) -> Self {
63 Self {
64 storage,
65 hooks: Mutex::new(HookManager::new()),
66 current_session: Mutex::new(None),
67 current_step: Mutex::new(None),
68 auto_save: true,
69 }
70 }
71
72 pub fn with_hooks(storage: Arc<dyn TraceStorage>, hooks: HookManager) -> Self {
74 Self {
75 storage,
76 hooks: Mutex::new(hooks),
77 current_session: Mutex::new(None),
78 current_step: Mutex::new(None),
79 auto_save: true,
80 }
81 }
82
83 pub fn set_auto_save(&mut self, auto_save: bool) {
85 self.auto_save = auto_save;
86 }
87
88 pub async fn register_hook(
90 &self,
91 event: HookEvent,
92 callback: super::hooks::HookCallback,
93 priority: i32,
94 ) {
95 let mut hooks = self.hooks.lock().await;
96 hooks.register(event, callback, priority);
97 }
98
99 pub async fn start_session(
114 &self,
115 session_id: Option<&str>,
116 metadata: HashMap<String, Value>,
117 ) -> Result<String, TracingError> {
118 let mut current = self.current_session.lock().await;
119
120 if current.is_some() {
121 return Err(TracingError::SessionAlreadyActive(
122 current.as_ref().unwrap().session_id.clone(),
123 ));
124 }
125
126 let session_id = session_id
127 .map(|s| s.to_string())
128 .unwrap_or_else(|| Uuid::new_v4().to_string());
129
130 let mut trace = SessionTrace::new(&session_id);
131 trace.metadata = metadata;
132
133 if self.auto_save {
135 self.storage
136 .ensure_session(
137 &session_id,
138 trace.created_at,
139 &serde_json::to_value(&trace.metadata).unwrap_or_default(),
140 )
141 .await?;
142 }
143
144 let context = HookContext::new().with_session(&session_id);
146 self.hooks.lock().await.trigger(HookEvent::SessionStart, &context);
147
148 *current = Some(trace);
149
150 Ok(session_id)
151 }
152
153 pub async fn end_session(&self, save: bool) -> Result<SessionTrace, TracingError> {
163 if self.current_step.lock().await.is_some() {
165 self.end_timestep().await?;
166 }
167
168 let mut current = self.current_session.lock().await;
169
170 let trace = current.take().ok_or(TracingError::NoActiveSession)?;
171
172 if self.auto_save || save {
174 self.storage.update_session_counts(&trace.session_id).await?;
175 }
176
177 let context = HookContext::new().with_session(&trace.session_id);
179 self.hooks.lock().await.trigger(HookEvent::SessionEnd, &context);
180
181 Ok(trace)
182 }
183
184 pub async fn current_session_id(&self) -> Option<String> {
186 self.current_session
187 .lock()
188 .await
189 .as_ref()
190 .map(|s| s.session_id.clone())
191 }
192
193 pub async fn start_timestep(
205 &self,
206 step_id: &str,
207 turn_number: Option<i32>,
208 metadata: HashMap<String, Value>,
209 ) -> Result<(), TracingError> {
210 let (session_id, step_index) = {
211 let session_guard = self.current_session.lock().await;
212 let session = session_guard.as_ref().ok_or(TracingError::NoActiveSession)?;
213 (session.session_id.clone(), session.session_time_steps.len() as i32)
214 };
215
216 if self.current_step.lock().await.is_some() {
218 self.end_timestep().await?;
219 }
220
221 let mut step = SessionTimeStep::new(step_id, step_index);
222 step.turn_number = turn_number;
223 step.step_metadata = metadata;
224
225 let db_id = if self.auto_save {
227 Some(self.storage.ensure_timestep(&session_id, &step).await?)
228 } else {
229 None
230 };
231
232 let context = HookContext::new()
234 .with_session(&session_id)
235 .with_step(step_id);
236 self.hooks.lock().await.trigger(HookEvent::TimestepStart, &context);
237
238 *self.current_step.lock().await = Some(CurrentStep { step, db_id });
239
240 Ok(())
241 }
242
243 pub async fn end_timestep(&self) -> Result<(), TracingError> {
245 let session_id = self.current_session_id().await.ok_or(TracingError::NoActiveSession)?;
246
247 let mut current_step = self.current_step.lock().await;
248 let mut step_data = current_step.take().ok_or(TracingError::NoActiveTimestep)?;
249
250 step_data.step.complete();
251
252 if self.auto_save {
254 self.storage
255 .update_timestep(&session_id, &step_data.step.step_id, step_data.step.completed_at)
256 .await?;
257 }
258
259 let context = HookContext::new()
261 .with_session(&session_id)
262 .with_step(&step_data.step.step_id);
263 self.hooks.lock().await.trigger(HookEvent::TimestepEnd, &context);
264
265 let mut session = self.current_session.lock().await;
267 if let Some(ref mut s) = *session {
268 s.session_time_steps.push(step_data.step);
269 }
270
271 Ok(())
272 }
273
274 pub async fn current_step_id(&self) -> Option<String> {
276 self.current_step
277 .lock()
278 .await
279 .as_ref()
280 .map(|s| s.step.step_id.clone())
281 }
282
283 pub async fn record_event(&self, event: TracingEvent) -> Result<Option<i64>, TracingError> {
297 let session_id = self.current_session_id().await.ok_or(TracingError::NoActiveSession)?;
298
299 let timestep_db_id = self.current_step.lock().await.as_ref().and_then(|s| s.db_id);
300
301 let event_id = if self.auto_save {
303 Some(
304 self.storage
305 .insert_event(&session_id, timestep_db_id, &event)
306 .await?,
307 )
308 } else {
309 None
310 };
311
312 let context = HookContext::new()
314 .with_session(&session_id)
315 .with_event(event.clone());
316 self.hooks.lock().await.trigger(HookEvent::EventRecorded, &context);
317
318 let mut session = self.current_session.lock().await;
320 if let Some(ref mut s) = *session {
321 s.event_history.push(event.clone());
322 }
323
324 let mut step = self.current_step.lock().await;
326 if let Some(ref mut s) = *step {
327 s.step.events.push(event);
328 }
329
330 Ok(event_id)
331 }
332
333 pub async fn record_message(
349 &self,
350 content: MessageContent,
351 message_type: &str,
352 metadata: HashMap<String, Value>,
353 ) -> Result<Option<i64>, TracingError> {
354 let session_id = self.current_session_id().await.ok_or(TracingError::NoActiveSession)?;
355
356 let timestep_db_id = self.current_step.lock().await.as_ref().and_then(|s| s.db_id);
357
358 let msg = MarkovBlanketMessage {
359 content,
360 message_type: message_type.to_string(),
361 time_record: TimeRecord::now(),
362 metadata,
363 };
364
365 let msg_id = if self.auto_save {
367 Some(
368 self.storage
369 .insert_message(&session_id, timestep_db_id, &msg)
370 .await?,
371 )
372 } else {
373 None
374 };
375
376 let context = HookContext::new()
378 .with_session(&session_id)
379 .with_message(msg.clone());
380 self.hooks.lock().await.trigger(HookEvent::MessageRecorded, &context);
381
382 let mut session = self.current_session.lock().await;
384 if let Some(ref mut s) = *session {
385 s.markov_blanket_message_history.push(msg.clone());
386 }
387
388 let mut step = self.current_step.lock().await;
390 if let Some(ref mut s) = *step {
391 s.step.markov_blanket_messages.push(msg);
392 }
393
394 Ok(msg_id)
395 }
396
397 pub async fn record_outcome_reward(
403 &self,
404 reward: OutcomeReward,
405 ) -> Result<Option<i64>, TracingError> {
406 let session_id = self.current_session_id().await.ok_or(TracingError::NoActiveSession)?;
407
408 let reward_id = if self.auto_save {
409 Some(
410 self.storage
411 .insert_outcome_reward(&session_id, &reward)
412 .await?,
413 )
414 } else {
415 None
416 };
417
418 Ok(reward_id)
419 }
420
421 pub async fn record_event_reward(
423 &self,
424 event_id: i64,
425 reward: EventReward,
426 ) -> Result<Option<i64>, TracingError> {
427 let session_id = self.current_session_id().await.ok_or(TracingError::NoActiveSession)?;
428
429 let turn_number = self
430 .current_step
431 .lock()
432 .await
433 .as_ref()
434 .and_then(|s| s.step.turn_number);
435
436 let reward_id = if self.auto_save {
437 Some(
438 self.storage
439 .insert_event_reward(&session_id, event_id, None, turn_number, &reward)
440 .await?,
441 )
442 } else {
443 None
444 };
445
446 Ok(reward_id)
447 }
448
449 pub async fn get_session(&self, session_id: &str) -> Result<Option<SessionTrace>, TracingError> {
455 self.storage.get_session(session_id).await
456 }
457
458 pub async fn delete_session(&self, session_id: &str) -> Result<bool, TracingError> {
460 self.storage.delete_session(session_id).await
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tracing::libsql_storage::LibsqlTraceStorage;
    use crate::tracing::models::{BaseEventFields, LMCAISEvent};

    /// Builds a tracer backed by an in-memory libsql store (auto-save enabled).
    async fn create_test_tracer() -> SessionTracer {
        let storage = Arc::new(LibsqlTraceStorage::new_memory().await.unwrap());
        SessionTracer::new(storage)
    }

    /// A minimal LM CAIS event for tests that only need *some* event.
    fn sample_event() -> TracingEvent {
        TracingEvent::Cais(LMCAISEvent {
            base: BaseEventFields::new("test-system"),
            model_name: "gpt-4".to_string(),
            provider: Some("openai".to_string()),
            input_tokens: Some(100),
            output_tokens: Some(50),
            ..Default::default()
        })
    }

    #[tokio::test]
    async fn test_session_lifecycle() {
        let tracer = create_test_tracer().await;

        let session_id = tracer
            .start_session(None, Default::default())
            .await
            .unwrap();
        assert!(!session_id.is_empty());
        assert_eq!(tracer.current_session_id().await, Some(session_id.clone()));

        let trace = tracer.end_session(true).await.unwrap();
        assert_eq!(trace.session_id, session_id);
        assert!(tracer.current_session_id().await.is_none());
    }

    #[tokio::test]
    async fn test_timestep_lifecycle() {
        let tracer = create_test_tracer().await;

        tracer.start_session(None, Default::default()).await.unwrap();

        tracer
            .start_timestep("step-1", Some(1), Default::default())
            .await
            .unwrap();
        assert_eq!(tracer.current_step_id().await, Some("step-1".to_string()));

        tracer.end_timestep().await.unwrap();
        assert!(tracer.current_step_id().await.is_none());

        let trace = tracer.end_session(true).await.unwrap();
        assert_eq!(trace.session_time_steps.len(), 1);
    }

    #[tokio::test]
    async fn test_start_timestep_closes_previous_step() {
        let tracer = create_test_tracer().await;
        tracer.start_session(None, Default::default()).await.unwrap();

        tracer
            .start_timestep("step-1", Some(1), Default::default())
            .await
            .unwrap();
        tracer
            .start_timestep("step-2", Some(2), Default::default())
            .await
            .unwrap();
        assert_eq!(tracer.current_step_id().await, Some("step-2".to_string()));

        let trace = tracer.end_session(true).await.unwrap();
        let ids: Vec<&str> = trace
            .session_time_steps
            .iter()
            .map(|s| s.step_id.as_str())
            .collect();
        assert_eq!(ids, ["step-1", "step-2"]);
    }

    #[tokio::test]
    async fn test_end_session_closes_open_timestep() {
        let tracer = create_test_tracer().await;
        tracer.start_session(None, Default::default()).await.unwrap();
        tracer
            .start_timestep("step-1", Some(1), Default::default())
            .await
            .unwrap();

        let trace = tracer.end_session(true).await.unwrap();

        assert_eq!(trace.session_time_steps.len(), 1);
        assert!(tracer.current_step_id().await.is_none());
    }

    #[tokio::test]
    async fn test_end_timestep_without_active_step() {
        let tracer = create_test_tracer().await;
        tracer.start_session(None, Default::default()).await.unwrap();

        let result = tracer.end_timestep().await;
        assert!(matches!(result, Err(TracingError::NoActiveTimestep)));
    }

    #[tokio::test]
    async fn test_operations_require_active_session() {
        let tracer = create_test_tracer().await;

        let step = tracer
            .start_timestep("step-1", None, Default::default())
            .await;
        assert!(matches!(step, Err(TracingError::NoActiveSession)));

        let end_step = tracer.end_timestep().await;
        assert!(matches!(end_step, Err(TracingError::NoActiveSession)));

        let msg = tracer
            .record_message(MessageContent::from_text("hi"), "user", Default::default())
            .await;
        assert!(matches!(msg, Err(TracingError::NoActiveSession)));

        let end = tracer.end_session(true).await;
        assert!(matches!(end, Err(TracingError::NoActiveSession)));
    }

    #[tokio::test]
    async fn test_event_recording() {
        let tracer = create_test_tracer().await;

        tracer.start_session(None, Default::default()).await.unwrap();
        tracer
            .start_timestep("step-1", Some(1), Default::default())
            .await
            .unwrap();

        let event_id = tracer.record_event(sample_event()).await.unwrap();
        assert!(event_id.is_some());

        tracer.end_timestep().await.unwrap();
        let trace = tracer.end_session(true).await.unwrap();

        assert_eq!(trace.event_history.len(), 1);
        assert_eq!(trace.session_time_steps[0].events.len(), 1);
    }

    #[tokio::test]
    async fn test_event_outside_timestep() {
        let tracer = create_test_tracer().await;
        tracer.start_session(None, Default::default()).await.unwrap();

        let event_id = tracer.record_event(sample_event()).await.unwrap();
        assert!(event_id.is_some());

        let trace = tracer.end_session(true).await.unwrap();
        assert_eq!(trace.event_history.len(), 1);
        assert!(trace.session_time_steps.is_empty());
    }

    #[tokio::test]
    async fn test_message_recording() {
        let tracer = create_test_tracer().await;

        tracer.start_session(None, Default::default()).await.unwrap();
        tracer
            .start_timestep("step-1", Some(1), Default::default())
            .await
            .unwrap();

        let content = MessageContent::from_text("Hello, world!");
        let msg_id = tracer
            .record_message(content, "user", Default::default())
            .await
            .unwrap();
        assert!(msg_id.is_some());

        tracer.end_timestep().await.unwrap();
        let trace = tracer.end_session(true).await.unwrap();

        assert_eq!(trace.markov_blanket_message_history.len(), 1);
    }

    #[tokio::test]
    async fn test_auto_save_disabled_keeps_trace_in_memory() {
        let mut tracer = create_test_tracer().await;
        tracer.set_auto_save(false);

        tracer.start_session(None, Default::default()).await.unwrap();
        tracer
            .start_timestep("step-1", Some(1), Default::default())
            .await
            .unwrap();

        let msg_id = tracer
            .record_message(MessageContent::from_text("hi"), "user", Default::default())
            .await
            .unwrap();
        assert!(msg_id.is_none());

        let trace = tracer.end_session(false).await.unwrap();
        assert_eq!(trace.markov_blanket_message_history.len(), 1);
        assert_eq!(trace.session_time_steps.len(), 1);
    }

    #[tokio::test]
    async fn test_custom_session_id() {
        let tracer = create_test_tracer().await;

        let session_id = tracer
            .start_session(Some("my-custom-id"), Default::default())
            .await
            .unwrap();

        assert_eq!(session_id, "my-custom-id");
    }

    #[tokio::test]
    async fn test_session_retrieval() {
        let tracer = create_test_tracer().await;

        let session_id = tracer
            .start_session(None, Default::default())
            .await
            .unwrap();
        tracer.end_session(true).await.unwrap();

        let retrieved = tracer.get_session(&session_id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().session_id, session_id);
    }

    #[tokio::test]
    async fn test_no_duplicate_sessions() {
        let tracer = create_test_tracer().await;

        tracer.start_session(None, Default::default()).await.unwrap();

        let result = tracer.start_session(None, Default::default()).await;
        assert!(matches!(result, Err(TracingError::SessionAlreadyActive(_))));
    }
}
606}