1use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10use uuid::Uuid;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct SessionId(pub Uuid);
15
16impl SessionId {
17 pub fn new() -> Self {
19 Self(Uuid::new_v4())
20 }
21
22 pub fn from_string(s: &str) -> Result<Self, uuid::Error> {
24 Ok(Self(Uuid::parse_str(s)?))
25 }
26
27 pub fn as_str(&self) -> String {
29 self.0.to_string()
30 }
31}
32
33impl Default for SessionId {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl std::fmt::Display for SessionId {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "{}", self.0)
42 }
43}
44
45#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
47pub struct Timestamp(pub u64);
48
49impl Timestamp {
50 pub fn now() -> Self {
52 let duration = SystemTime::now()
53 .duration_since(UNIX_EPOCH)
54 .unwrap_or(Duration::ZERO);
55 Self(duration.as_millis() as u64)
56 }
57
58 pub fn from_millis(millis: u64) -> Self {
60 Self(millis)
61 }
62
63 pub fn as_millis(&self) -> u64 {
65 self.0
66 }
67}
68
69impl Default for Timestamp {
70 fn default() -> Self {
71 Self::now()
72 }
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
77#[serde(tag = "type", rename_all = "snake_case")]
78pub enum SessionType {
79 TeacherDemo {
81 task: String,
83 model: String,
85 },
86
87 Evaluation {
89 skill_name: String,
91 test_cases: usize,
93 },
94
95 Refinement {
97 skill_name: String,
99 iteration: usize,
101 },
102
103 Conversation {
105 purpose: String,
107 },
108
109 Agent {
111 agent_name: String,
113 },
114}
115
116#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
118#[serde(rename_all = "snake_case")]
119pub enum SessionStatus {
120 #[default]
122 Active,
123 Completed,
125 Failed,
127 Cancelled,
129 Paused,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct SessionMetadata {
136 pub id: SessionId,
138 pub name: String,
140 pub session_type: SessionType,
142 pub created_at: Timestamp,
144 pub updated_at: Timestamp,
146 pub status: SessionStatus,
148 #[serde(default)]
150 pub tags: Vec<String>,
151 #[serde(default)]
153 pub parent_session: Option<SessionId>,
154}
155
156impl SessionMetadata {
157 pub fn new(name: impl Into<String>, session_type: SessionType) -> Self {
159 let now = Timestamp::now();
160 Self {
161 id: SessionId::new(),
162 name: name.into(),
163 session_type,
164 created_at: now,
165 updated_at: now,
166 status: SessionStatus::Active,
167 tags: Vec::new(),
168 parent_session: None,
169 }
170 }
171
172 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
174 self.tags.push(tag.into());
175 self
176 }
177
178 pub fn with_parent(mut self, parent: SessionId) -> Self {
180 self.parent_session = Some(parent);
181 self
182 }
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
187#[serde(tag = "type", rename_all = "snake_case")]
188pub enum EntryType {
189 UserMessage,
191
192 AssistantMessage,
194
195 SystemMessage,
197
198 ToolCall {
200 tool_name: String,
202 success: bool,
204 },
205
206 SkillExecution {
208 skill_name: String,
210 success: bool,
212 },
213
214 EvaluationResult {
216 score: f64,
218 metrics: HashMap<String, f64>,
220 },
221
222 SystemEvent {
224 event: String,
226 },
227}
228
229#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct SessionEntry {
232 pub id: Uuid,
234 pub timestamp: Timestamp,
236 pub entry_type: EntryType,
238 pub content: Value,
240}
241
242impl SessionEntry {
243 pub fn new(entry_type: EntryType, content: Value) -> Self {
245 Self {
246 id: Uuid::new_v4(),
247 timestamp: Timestamp::now(),
248 entry_type,
249 content,
250 }
251 }
252
253 pub fn user_message(content: impl Into<String>) -> Self {
255 Self::new(
256 EntryType::UserMessage,
257 serde_json::json!({ "text": content.into() }),
258 )
259 }
260
261 pub fn assistant_message(content: impl Into<String>) -> Self {
263 Self::new(
264 EntryType::AssistantMessage,
265 serde_json::json!({ "text": content.into() }),
266 )
267 }
268
269 pub fn tool_call(tool_name: impl Into<String>, success: bool, result: Value) -> Self {
271 Self::new(
272 EntryType::ToolCall {
273 tool_name: tool_name.into(),
274 success,
275 },
276 result,
277 )
278 }
279
280 pub fn skill_execution(skill_name: impl Into<String>, success: bool, result: Value) -> Self {
282 Self::new(
283 EntryType::SkillExecution {
284 skill_name: skill_name.into(),
285 success,
286 },
287 result,
288 )
289 }
290}
291
292#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct Session {
295 pub metadata: SessionMetadata,
297 pub entries: Vec<SessionEntry>,
299 #[serde(default)]
301 pub context: HashMap<String, Value>,
302}
303
304impl Session {
305 pub fn new(name: impl Into<String>, session_type: SessionType) -> Self {
307 Self {
308 metadata: SessionMetadata::new(name, session_type),
309 entries: Vec::new(),
310 context: HashMap::new(),
311 }
312 }
313
314 pub fn id(&self) -> &SessionId {
316 &self.metadata.id
317 }
318
319 pub fn name(&self) -> &str {
321 &self.metadata.name
322 }
323
324 pub fn status(&self) -> SessionStatus {
326 self.metadata.status
327 }
328
329 pub fn add_entry(&mut self, entry: SessionEntry) {
331 self.entries.push(entry);
332 self.metadata.updated_at = Timestamp::now();
333 }
334
335 pub fn add_user_message(&mut self, content: impl Into<String>) {
337 self.add_entry(SessionEntry::user_message(content));
338 }
339
340 pub fn add_assistant_message(&mut self, content: impl Into<String>) {
342 self.add_entry(SessionEntry::assistant_message(content));
343 }
344
345 pub fn set_context(&mut self, key: impl Into<String>, value: Value) {
347 self.context.insert(key.into(), value);
348 self.metadata.updated_at = Timestamp::now();
349 }
350
351 pub fn get_context(&self, key: &str) -> Option<&Value> {
353 self.context.get(key)
354 }
355
356 pub fn set_status(&mut self, status: SessionStatus) {
358 self.metadata.status = status;
359 self.metadata.updated_at = Timestamp::now();
360 }
361
362 pub fn complete(&mut self) {
364 self.set_status(SessionStatus::Completed);
365 }
366
367 pub fn fail(&mut self) {
369 self.set_status(SessionStatus::Failed);
370 }
371
372 pub fn turn_count(&self) -> u32 {
377 let mut turns = 0;
378 let mut awaiting_response = false;
379
380 for entry in &self.entries {
381 match &entry.entry_type {
382 EntryType::UserMessage => {
383 awaiting_response = true;
384 }
385 EntryType::AssistantMessage => {
386 if awaiting_response {
387 turns += 1;
388 awaiting_response = false;
389 }
390 }
391 _ => {}
392 }
393 }
394
395 turns
396 }
397
398 pub fn user_message_count(&self) -> usize {
400 self.entries
401 .iter()
402 .filter(|e| matches!(e.entry_type, EntryType::UserMessage))
403 .count()
404 }
405
406 pub fn assistant_message_count(&self) -> usize {
408 self.entries
409 .iter()
410 .filter(|e| matches!(e.entry_type, EntryType::AssistantMessage))
411 .count()
412 }
413
414 pub fn duration(&self) -> Duration {
416 let start = self.metadata.created_at.as_millis();
417 let end = self.metadata.updated_at.as_millis();
418 Duration::from_millis(end.saturating_sub(start))
419 }
420}
421
422#[derive(Debug, Clone, Serialize, Deserialize)]
424pub struct SessionConfig {
425 pub max_turns: Option<u32>,
427 pub max_entries: Option<usize>,
429 pub max_duration: Option<Duration>,
431 pub limit_action: LimitAction,
433}
434
435impl Default for SessionConfig {
436 fn default() -> Self {
437 Self {
438 max_turns: None,
439 max_entries: None,
440 max_duration: None,
441 limit_action: LimitAction::Error,
442 }
443 }
444}
445
446impl SessionConfig {
447 pub fn new() -> Self {
449 Self::default()
450 }
451
452 pub fn with_max_turns(mut self, max: u32) -> Self {
454 self.max_turns = Some(max);
455 self
456 }
457
458 pub fn with_max_entries(mut self, max: usize) -> Self {
460 self.max_entries = Some(max);
461 self
462 }
463
464 pub fn with_max_duration(mut self, max: Duration) -> Self {
466 self.max_duration = Some(max);
467 self
468 }
469
470 pub fn with_limit_action(mut self, action: LimitAction) -> Self {
472 self.limit_action = action;
473 self
474 }
475}
476
477#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
479#[serde(rename_all = "snake_case")]
480pub enum LimitAction {
481 #[default]
483 Error,
484 EndSession,
486 Ignore,
488}
489
/// Outcome of `Session::check_limits`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum LimitCheck {
    /// No configured limit is reached.
    Ok,
    /// A turn or entry count equals its configured maximum.
    AtLimit,
    /// A configured limit has been passed; the payload says which one.
    Exceeded(LimitExceeded),
}

/// The limit that was exceeded, with the observed value and the configured maximum.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum LimitExceeded {
    /// Completed turns exceeded `max_turns`.
    Turns { current: u32, max: u32 },
    /// Entry count exceeded `max_entries`.
    Entries { current: usize, max: usize },
    /// Session duration exceeded `max_duration`.
    Duration { current: Duration, max: Duration },
}
511
512impl Session {
513 pub fn check_limits(&self, config: &SessionConfig) -> LimitCheck {
515 if let Some(max_turns) = config.max_turns {
517 let current = self.turn_count();
518 if current > max_turns {
519 return LimitCheck::Exceeded(LimitExceeded::Turns {
520 current,
521 max: max_turns,
522 });
523 }
524 if current == max_turns {
525 return LimitCheck::AtLimit;
526 }
527 }
528
529 if let Some(max_entries) = config.max_entries {
531 let current = self.entries.len();
532 if current > max_entries {
533 return LimitCheck::Exceeded(LimitExceeded::Entries {
534 current,
535 max: max_entries,
536 });
537 }
538 if current == max_entries {
539 return LimitCheck::AtLimit;
540 }
541 }
542
543 if let Some(max_duration) = config.max_duration {
545 let current = self.duration();
546 if current > max_duration {
547 return LimitCheck::Exceeded(LimitExceeded::Duration {
548 current,
549 max: max_duration,
550 });
551 }
552 }
553
554 LimitCheck::Ok
555 }
556
557 pub fn is_at_turn_limit(&self, config: &SessionConfig) -> bool {
559 if let Some(max_turns) = config.max_turns {
560 self.turn_count() >= max_turns
561 } else {
562 false
563 }
564 }
565
566 pub fn remaining_turns(&self, config: &SessionConfig) -> Option<u32> {
568 config
569 .max_turns
570 .map(|max| max.saturating_sub(self.turn_count()))
571 }
572}
573
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty conversation session; most tests only need this shape.
    fn conversation(name: &str, purpose: &str) -> Session {
        Session::new(
            name,
            SessionType::Conversation {
                purpose: purpose.to_string(),
            },
        )
    }

    #[test]
    fn test_session_id() {
        let original = SessionId::new();
        let text = original.as_str();
        let reparsed = SessionId::from_string(&text).unwrap();
        assert_eq!(original, reparsed);
    }

    #[test]
    fn test_timestamp() {
        let before = Timestamp::now();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let after = Timestamp::now();
        assert!(after.as_millis() > before.as_millis());
    }

    #[test]
    fn test_session_creation() {
        let session = conversation("Test Session", "Testing");

        assert_eq!(session.name(), "Test Session");
        assert_eq!(session.status(), SessionStatus::Active);
        assert!(session.entries.is_empty());
    }

    #[test]
    fn test_session_entries() {
        let mut session = conversation("Test", "Test");

        let exchanges = [("Hello", "Hi there!"), ("How are you?", "I'm doing well!")];
        for (question, answer) in exchanges {
            session.add_user_message(question);
            session.add_assistant_message(answer);
        }

        assert_eq!(session.entries.len(), 4);
        assert_eq!(session.user_message_count(), 2);
        assert_eq!(session.assistant_message_count(), 2);
    }

    #[test]
    fn test_turn_counting() {
        let mut session = conversation("Test", "Test");
        assert_eq!(session.turn_count(), 0);

        // A user message on its own does not complete a turn.
        session.add_user_message("Hello");
        assert_eq!(session.turn_count(), 0);

        // The assistant's reply closes the first turn.
        session.add_assistant_message("Hi!");
        assert_eq!(session.turn_count(), 1);

        session.add_user_message("How are you?");
        session.add_assistant_message("Good!");
        assert_eq!(session.turn_count(), 2);
    }

    #[test]
    fn test_session_limits() {
        let mut session = conversation("Test", "Test");
        let config = SessionConfig::new().with_max_turns(2);

        // Fresh session: nothing consumed yet.
        assert_eq!(session.check_limits(&config), LimitCheck::Ok);
        assert_eq!(session.remaining_turns(&config), Some(2));
        assert!(!session.is_at_turn_limit(&config));

        session.add_user_message("Hello");
        session.add_assistant_message("Hi!");
        assert_eq!(session.remaining_turns(&config), Some(1));

        // The second turn reaches the limit exactly.
        session.add_user_message("How are you?");
        session.add_assistant_message("Good!");
        assert_eq!(session.check_limits(&config), LimitCheck::AtLimit);
        assert!(session.is_at_turn_limit(&config));
        assert_eq!(session.remaining_turns(&config), Some(0));

        // The third turn goes past it.
        session.add_user_message("What's up?");
        session.add_assistant_message("Nothing much!");
        assert_eq!(
            session.check_limits(&config),
            LimitCheck::Exceeded(LimitExceeded::Turns { current: 3, max: 2 })
        );
    }

    #[test]
    fn test_session_context() {
        let mut session = conversation("Test", "Test");

        session.set_context("key1", serde_json::json!("value1"));
        session.set_context("key2", serde_json::json!(42));

        assert_eq!(
            session.get_context("key1"),
            Some(&serde_json::json!("value1"))
        );
        assert_eq!(session.get_context("key2"), Some(&serde_json::json!(42)));
        assert_eq!(session.get_context("key3"), None);
    }

    #[test]
    fn test_session_status() {
        let mut session = conversation("Test", "Test");
        assert_eq!(session.status(), SessionStatus::Active);

        session.complete();
        assert_eq!(session.status(), SessionStatus::Completed);

        session.fail();
        assert_eq!(session.status(), SessionStatus::Failed);
    }

    #[test]
    fn test_session_serialization() {
        let mut session = Session::new(
            "Test",
            SessionType::TeacherDemo {
                task: "Demo task".to_string(),
                model: "gpt-4".to_string(),
            },
        );
        session.add_user_message("Hello");
        session.add_assistant_message("Hi!");
        session.set_context("key", serde_json::json!("value"));

        let encoded = serde_json::to_string(&session).unwrap();
        let decoded: Session = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.name(), session.name());
        assert_eq!(decoded.entries.len(), 2);
        assert_eq!(
            decoded.get_context("key"),
            Some(&serde_json::json!("value"))
        );
    }

    #[test]
    fn test_session_entry_types() {
        let tool_entry =
            SessionEntry::tool_call("my_tool", true, serde_json::json!({"result": "ok"}));
        assert!(matches!(
            &tool_entry.entry_type,
            EntryType::ToolCall { tool_name, success } if tool_name == "my_tool" && *success
        ));

        let skill_entry = SessionEntry::skill_execution("my_skill", false, serde_json::json!({}));
        assert!(matches!(
            &skill_entry.entry_type,
            EntryType::SkillExecution { skill_name, success } if skill_name == "my_skill" && !*success
        ));
    }
}
763}