1use std::collections::HashMap;
7use std::sync::Arc;
8
9use serde_json::Value;
10use tokio::sync::Mutex;
11use uuid::Uuid;
12
13use super::error::TracingError;
14use super::hooks::{HookContext, HookEvent, HookManager};
15use super::models::{
16 EventReward, MarkovBlanketMessage, MessageContent, OutcomeReward, SessionTimeStep,
17 SessionTrace, TimeRecord, TracingEvent,
18};
19use super::storage::{QueryParams, TraceStorage};
20
21pub struct SessionTracer {
42 storage: Arc<dyn TraceStorage>,
44 hooks: Mutex<HookManager>,
46 current_session: Mutex<Option<SessionTrace>>,
48 current_step: Mutex<Option<CurrentStep>>,
50 auto_save: bool,
52}
53
54struct CurrentStep {
56 step: SessionTimeStep,
57 db_id: Option<i64>,
58}
59
60impl SessionTracer {
61 pub fn new(storage: Arc<dyn TraceStorage>) -> Self {
63 Self {
64 storage,
65 hooks: Mutex::new(HookManager::new()),
66 current_session: Mutex::new(None),
67 current_step: Mutex::new(None),
68 auto_save: true,
69 }
70 }
71
72 pub fn with_hooks(storage: Arc<dyn TraceStorage>, hooks: HookManager) -> Self {
74 Self {
75 storage,
76 hooks: Mutex::new(hooks),
77 current_session: Mutex::new(None),
78 current_step: Mutex::new(None),
79 auto_save: true,
80 }
81 }
82
83 pub fn set_auto_save(&mut self, auto_save: bool) {
85 self.auto_save = auto_save;
86 }
87
88 pub async fn register_hook(
90 &self,
91 event: HookEvent,
92 callback: super::hooks::HookCallback,
93 priority: i32,
94 ) {
95 let mut hooks = self.hooks.lock().await;
96 hooks.register(event, callback, priority);
97 }
98
99 pub async fn start_session(
114 &self,
115 session_id: Option<&str>,
116 metadata: HashMap<String, Value>,
117 ) -> Result<String, TracingError> {
118 let mut current = self.current_session.lock().await;
119
120 if current.is_some() {
121 return Err(TracingError::SessionAlreadyActive(
122 current.as_ref().unwrap().session_id.clone(),
123 ));
124 }
125
126 let session_id = session_id
127 .map(|s| s.to_string())
128 .unwrap_or_else(|| Uuid::new_v4().to_string());
129
130 let mut trace = SessionTrace::new(&session_id);
131 trace.metadata = metadata;
132
133 if self.auto_save {
135 self.storage
136 .ensure_session(
137 &session_id,
138 trace.created_at,
139 &serde_json::to_value(&trace.metadata).unwrap_or_default(),
140 )
141 .await?;
142 }
143
144 let context = HookContext::new().with_session(&session_id);
146 self.hooks
147 .lock()
148 .await
149 .trigger(HookEvent::SessionStart, &context);
150
151 *current = Some(trace);
152
153 Ok(session_id)
154 }
155
156 pub async fn end_session(&self, save: bool) -> Result<SessionTrace, TracingError> {
166 if self.current_step.lock().await.is_some() {
168 self.end_timestep().await?;
169 }
170
171 let mut current = self.current_session.lock().await;
172
173 let trace = current.take().ok_or(TracingError::NoActiveSession)?;
174
175 if self.auto_save || save {
177 self.storage
178 .update_session_counts(&trace.session_id)
179 .await?;
180 }
181
182 let context = HookContext::new().with_session(&trace.session_id);
184 self.hooks
185 .lock()
186 .await
187 .trigger(HookEvent::SessionEnd, &context);
188
189 Ok(trace)
190 }
191
192 pub async fn current_session_id(&self) -> Option<String> {
194 self.current_session
195 .lock()
196 .await
197 .as_ref()
198 .map(|s| s.session_id.clone())
199 }
200
201 pub async fn query(&self, sql: &str, params: QueryParams) -> Result<Vec<Value>, TracingError> {
203 self.storage.query(sql, params).await
204 }
205
206 pub async fn start_timestep(
218 &self,
219 step_id: &str,
220 turn_number: Option<i32>,
221 metadata: HashMap<String, Value>,
222 ) -> Result<(), TracingError> {
223 let (session_id, step_index) = {
224 let session_guard = self.current_session.lock().await;
225 let session = session_guard
226 .as_ref()
227 .ok_or(TracingError::NoActiveSession)?;
228 (
229 session.session_id.clone(),
230 session.session_time_steps.len() as i32,
231 )
232 };
233
234 if self.current_step.lock().await.is_some() {
236 self.end_timestep().await?;
237 }
238
239 let mut step = SessionTimeStep::new(step_id, step_index);
240 step.turn_number = turn_number;
241 step.step_metadata = metadata;
242
243 let db_id = if self.auto_save {
245 Some(self.storage.ensure_timestep(&session_id, &step).await?)
246 } else {
247 None
248 };
249
250 let context = HookContext::new()
252 .with_session(&session_id)
253 .with_step(step_id);
254 self.hooks
255 .lock()
256 .await
257 .trigger(HookEvent::TimestepStart, &context);
258
259 *self.current_step.lock().await = Some(CurrentStep { step, db_id });
260
261 Ok(())
262 }
263
264 pub async fn end_timestep(&self) -> Result<(), TracingError> {
266 let session_id = self
267 .current_session_id()
268 .await
269 .ok_or(TracingError::NoActiveSession)?;
270
271 let mut current_step = self.current_step.lock().await;
272 let mut step_data = current_step.take().ok_or(TracingError::NoActiveTimestep)?;
273
274 step_data.step.complete();
275
276 if self.auto_save {
278 self.storage
279 .update_timestep(
280 &session_id,
281 &step_data.step.step_id,
282 step_data.step.completed_at,
283 )
284 .await?;
285 }
286
287 let context = HookContext::new()
289 .with_session(&session_id)
290 .with_step(&step_data.step.step_id);
291 self.hooks
292 .lock()
293 .await
294 .trigger(HookEvent::TimestepEnd, &context);
295
296 let mut session = self.current_session.lock().await;
298 if let Some(ref mut s) = *session {
299 s.session_time_steps.push(step_data.step);
300 }
301
302 Ok(())
303 }
304
305 pub async fn current_step_id(&self) -> Option<String> {
307 self.current_step
308 .lock()
309 .await
310 .as_ref()
311 .map(|s| s.step.step_id.clone())
312 }
313
314 pub async fn record_event(&self, event: TracingEvent) -> Result<Option<i64>, TracingError> {
328 let session_id = self
329 .current_session_id()
330 .await
331 .ok_or(TracingError::NoActiveSession)?;
332
333 let timestep_db_id = self
334 .current_step
335 .lock()
336 .await
337 .as_ref()
338 .and_then(|s| s.db_id);
339
340 let event_id = if self.auto_save {
342 Some(
343 self.storage
344 .insert_event(&session_id, timestep_db_id, &event)
345 .await?,
346 )
347 } else {
348 None
349 };
350
351 let context = HookContext::new()
353 .with_session(&session_id)
354 .with_event(event.clone());
355 self.hooks
356 .lock()
357 .await
358 .trigger(HookEvent::EventRecorded, &context);
359
360 let mut session = self.current_session.lock().await;
362 if let Some(ref mut s) = *session {
363 s.event_history.push(event.clone());
364 }
365
366 let mut step = self.current_step.lock().await;
368 if let Some(ref mut s) = *step {
369 s.step.events.push(event);
370 }
371
372 Ok(event_id)
373 }
374
375 pub async fn record_message(
391 &self,
392 content: MessageContent,
393 message_type: &str,
394 metadata: HashMap<String, Value>,
395 ) -> Result<Option<i64>, TracingError> {
396 let session_id = self
397 .current_session_id()
398 .await
399 .ok_or(TracingError::NoActiveSession)?;
400
401 let timestep_db_id = self
402 .current_step
403 .lock()
404 .await
405 .as_ref()
406 .and_then(|s| s.db_id);
407
408 let msg = MarkovBlanketMessage {
409 content,
410 message_type: message_type.to_string(),
411 time_record: TimeRecord::now(),
412 metadata,
413 };
414
415 let msg_id = if self.auto_save {
417 Some(
418 self.storage
419 .insert_message(&session_id, timestep_db_id, &msg)
420 .await?,
421 )
422 } else {
423 None
424 };
425
426 let context = HookContext::new()
428 .with_session(&session_id)
429 .with_message(msg.clone());
430 self.hooks
431 .lock()
432 .await
433 .trigger(HookEvent::MessageRecorded, &context);
434
435 let mut session = self.current_session.lock().await;
437 if let Some(ref mut s) = *session {
438 s.markov_blanket_message_history.push(msg.clone());
439 }
440
441 let mut step = self.current_step.lock().await;
443 if let Some(ref mut s) = *step {
444 s.step.markov_blanket_messages.push(msg);
445 }
446
447 Ok(msg_id)
448 }
449
450 pub async fn record_outcome_reward(
456 &self,
457 reward: OutcomeReward,
458 ) -> Result<Option<i64>, TracingError> {
459 let session_id = self
460 .current_session_id()
461 .await
462 .ok_or(TracingError::NoActiveSession)?;
463
464 let reward_id = if self.auto_save {
465 Some(
466 self.storage
467 .insert_outcome_reward(&session_id, &reward)
468 .await?,
469 )
470 } else {
471 None
472 };
473
474 Ok(reward_id)
475 }
476
477 pub async fn record_event_reward(
479 &self,
480 event_id: i64,
481 reward: EventReward,
482 ) -> Result<Option<i64>, TracingError> {
483 let session_id = self
484 .current_session_id()
485 .await
486 .ok_or(TracingError::NoActiveSession)?;
487
488 let turn_number = self
489 .current_step
490 .lock()
491 .await
492 .as_ref()
493 .and_then(|s| s.step.turn_number);
494
495 let reward_id = if self.auto_save {
496 Some(
497 self.storage
498 .insert_event_reward(&session_id, event_id, None, turn_number, &reward)
499 .await?,
500 )
501 } else {
502 None
503 };
504
505 Ok(reward_id)
506 }
507
508 pub async fn get_session(
514 &self,
515 session_id: &str,
516 ) -> Result<Option<SessionTrace>, TracingError> {
517 self.storage.get_session(session_id).await
518 }
519
520 pub async fn delete_session(&self, session_id: &str) -> Result<bool, TracingError> {
522 self.storage.delete_session(session_id).await
523 }
524}
525
#[cfg(all(test, feature = "libsql"))]
mod tests {
    use super::*;
    use crate::tracing::libsql_storage::LibsqlTraceStorage;
    use crate::tracing::models::{BaseEventFields, LMCAISEvent};

    /// Builds a tracer backed by an in-memory libsql store.
    async fn create_test_tracer() -> SessionTracer {
        let backend = LibsqlTraceStorage::new_memory().await.unwrap();
        SessionTracer::new(Arc::new(backend))
    }

    /// Builds a tracer that already has a session and the step "step-1" open.
    async fn tracer_with_open_step() -> SessionTracer {
        let tracer = create_test_tracer().await;
        tracer.start_session(None, HashMap::new()).await.unwrap();
        tracer
            .start_timestep("step-1", Some(1), HashMap::new())
            .await
            .unwrap();
        tracer
    }

    #[tokio::test]
    async fn test_session_lifecycle() {
        let tracer = create_test_tracer().await;

        // Starting a session yields a non-empty id and marks it active.
        let started_id = tracer.start_session(None, HashMap::new()).await.unwrap();
        assert!(!started_id.is_empty());
        assert_eq!(tracer.current_session_id().await, Some(started_id.clone()));

        // Ending it returns the same session and clears the active slot.
        let finished = tracer.end_session(true).await.unwrap();
        assert_eq!(finished.session_id, started_id);
        assert!(tracer.current_session_id().await.is_none());
    }

    #[tokio::test]
    async fn test_timestep_lifecycle() {
        let tracer = create_test_tracer().await;
        tracer.start_session(None, HashMap::new()).await.unwrap();

        tracer
            .start_timestep("step-1", Some(1), HashMap::new())
            .await
            .unwrap();
        assert_eq!(tracer.current_step_id().await, Some("step-1".to_string()));

        tracer.end_timestep().await.unwrap();
        assert!(tracer.current_step_id().await.is_none());

        let finished = tracer.end_session(true).await.unwrap();
        assert_eq!(finished.session_time_steps.len(), 1);
    }

    #[tokio::test]
    async fn test_event_recording() {
        let tracer = tracer_with_open_step().await;

        let lm_call = LMCAISEvent {
            base: BaseEventFields::new("test-system"),
            model_name: "gpt-4".to_string(),
            provider: Some("openai".to_string()),
            input_tokens: Some(100),
            output_tokens: Some(50),
            ..Default::default()
        };

        let stored_id = tracer
            .record_event(TracingEvent::Cais(lm_call))
            .await
            .unwrap();
        assert!(stored_id.is_some());

        tracer.end_timestep().await.unwrap();
        let finished = tracer.end_session(true).await.unwrap();

        // The event shows up both session-wide and on its timestep.
        assert_eq!(finished.event_history.len(), 1);
        assert_eq!(finished.session_time_steps[0].events.len(), 1);
    }

    #[tokio::test]
    async fn test_message_recording() {
        let tracer = tracer_with_open_step().await;

        let stored_id = tracer
            .record_message(
                MessageContent::from_text("Hello, world!"),
                "user",
                HashMap::new(),
            )
            .await
            .unwrap();
        assert!(stored_id.is_some());

        tracer.end_timestep().await.unwrap();
        let finished = tracer.end_session(true).await.unwrap();

        assert_eq!(finished.markov_blanket_message_history.len(), 1);
    }

    #[tokio::test]
    async fn test_custom_session_id() {
        let tracer = create_test_tracer().await;

        let started_id = tracer
            .start_session(Some("my-custom-id"), HashMap::new())
            .await
            .unwrap();

        assert_eq!(started_id, "my-custom-id");
    }

    #[tokio::test]
    async fn test_session_retrieval() {
        let tracer = create_test_tracer().await;

        let started_id = tracer.start_session(None, HashMap::new()).await.unwrap();
        tracer.end_session(true).await.unwrap();

        // A finished session can be loaded back from storage by id.
        let loaded = tracer.get_session(&started_id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().session_id, started_id);
    }

    #[tokio::test]
    async fn test_no_duplicate_sessions() {
        let tracer = create_test_tracer().await;
        tracer.start_session(None, HashMap::new()).await.unwrap();

        // A second start while one is active must be rejected.
        let second = tracer.start_session(None, HashMap::new()).await;
        assert!(matches!(second, Err(TracingError::SessionAlreadyActive(_))));
    }
}