Skip to main content

thulp_workspace/
session.rs

1//! Session types and management for thulp.
2//!
3//! This module provides session tracking for conversation history,
4//! evaluation runs, and skill execution sessions.
5
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10use uuid::Uuid;
11
12/// Unique session identifier.
13#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct SessionId(pub Uuid);
15
16impl SessionId {
17    /// Create a new random session ID.
18    pub fn new() -> Self {
19        Self(Uuid::new_v4())
20    }
21
22    /// Parse a session ID from a string.
23    pub fn from_string(s: &str) -> Result<Self, uuid::Error> {
24        Ok(Self(Uuid::parse_str(s)?))
25    }
26
27    /// Get the UUID as a string.
28    pub fn as_str(&self) -> String {
29        self.0.to_string()
30    }
31}
32
33impl Default for SessionId {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl std::fmt::Display for SessionId {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "{}", self.0)
42    }
43}
44
45/// Timestamp in milliseconds since Unix epoch.
46#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
47pub struct Timestamp(pub u64);
48
49impl Timestamp {
50    /// Create a timestamp for the current time.
51    pub fn now() -> Self {
52        let duration = SystemTime::now()
53            .duration_since(UNIX_EPOCH)
54            .unwrap_or(Duration::ZERO);
55        Self(duration.as_millis() as u64)
56    }
57
58    /// Create a timestamp from milliseconds.
59    pub fn from_millis(millis: u64) -> Self {
60        Self(millis)
61    }
62
63    /// Get the timestamp as milliseconds.
64    pub fn as_millis(&self) -> u64 {
65        self.0
66    }
67}
68
69impl Default for Timestamp {
70    fn default() -> Self {
71        Self::now()
72    }
73}
74
75/// Type of session.
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
77#[serde(tag = "type", rename_all = "snake_case")]
78pub enum SessionType {
79    /// Teacher demonstration session (for distillation).
80    TeacherDemo {
81        /// Task being demonstrated.
82        task: String,
83        /// Model used for demonstration.
84        model: String,
85    },
86
87    /// Skill evaluation run.
88    Evaluation {
89        /// Name of the skill being evaluated.
90        skill_name: String,
91        /// Number of test cases.
92        test_cases: usize,
93    },
94
95    /// Skill refinement session.
96    Refinement {
97        /// Name of the skill being refined.
98        skill_name: String,
99        /// Iteration number.
100        iteration: usize,
101    },
102
103    /// Generic conversation session.
104    Conversation {
105        /// Purpose of the conversation.
106        purpose: String,
107    },
108
109    /// Agent interaction session.
110    Agent {
111        /// Agent name or identifier.
112        agent_name: String,
113    },
114}
115
116/// Status of a session.
117#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
118#[serde(rename_all = "snake_case")]
119pub enum SessionStatus {
120    /// Session is currently active.
121    #[default]
122    Active,
123    /// Session completed successfully.
124    Completed,
125    /// Session failed with an error.
126    Failed,
127    /// Session was cancelled.
128    Cancelled,
129    /// Session is paused.
130    Paused,
131}
132
133/// Session metadata.
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct SessionMetadata {
136    /// Unique session identifier.
137    pub id: SessionId,
138    /// Human-readable session name.
139    pub name: String,
140    /// Type of session.
141    pub session_type: SessionType,
142    /// When the session was created.
143    pub created_at: Timestamp,
144    /// When the session was last updated.
145    pub updated_at: Timestamp,
146    /// Current session status.
147    pub status: SessionStatus,
148    /// Tags for categorization.
149    #[serde(default)]
150    pub tags: Vec<String>,
151    /// Parent session ID (for linked sessions).
152    #[serde(default)]
153    pub parent_session: Option<SessionId>,
154}
155
156impl SessionMetadata {
157    /// Create new session metadata.
158    pub fn new(name: impl Into<String>, session_type: SessionType) -> Self {
159        let now = Timestamp::now();
160        Self {
161            id: SessionId::new(),
162            name: name.into(),
163            session_type,
164            created_at: now,
165            updated_at: now,
166            status: SessionStatus::Active,
167            tags: Vec::new(),
168            parent_session: None,
169        }
170    }
171
172    /// Add a tag.
173    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
174        self.tags.push(tag.into());
175        self
176    }
177
178    /// Set parent session.
179    pub fn with_parent(mut self, parent: SessionId) -> Self {
180        self.parent_session = Some(parent);
181        self
182    }
183}
184
185/// Type of entry in a session.
186#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
187#[serde(tag = "type", rename_all = "snake_case")]
188pub enum EntryType {
189    /// User message.
190    UserMessage,
191
192    /// Assistant/AI response.
193    AssistantMessage,
194
195    /// System message.
196    SystemMessage,
197
198    /// Tool invocation.
199    ToolCall {
200        /// Name of the tool called.
201        tool_name: String,
202        /// Whether the call succeeded.
203        success: bool,
204    },
205
206    /// Skill execution.
207    SkillExecution {
208        /// Name of the skill.
209        skill_name: String,
210        /// Whether execution succeeded.
211        success: bool,
212    },
213
214    /// Evaluation result.
215    EvaluationResult {
216        /// Overall score (0.0 - 1.0).
217        score: f64,
218        /// Detailed metrics.
219        metrics: HashMap<String, f64>,
220    },
221
222    /// System event (logging, state changes, etc.).
223    SystemEvent {
224        /// Event type/name.
225        event: String,
226    },
227}
228
229/// A single entry in a session.
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct SessionEntry {
232    /// Unique entry identifier.
233    pub id: Uuid,
234    /// When the entry was created.
235    pub timestamp: Timestamp,
236    /// Type of entry.
237    pub entry_type: EntryType,
238    /// Entry content/data.
239    pub content: Value,
240}
241
242impl SessionEntry {
243    /// Create a new session entry.
244    pub fn new(entry_type: EntryType, content: Value) -> Self {
245        Self {
246            id: Uuid::new_v4(),
247            timestamp: Timestamp::now(),
248            entry_type,
249            content,
250        }
251    }
252
253    /// Create a user message entry.
254    pub fn user_message(content: impl Into<String>) -> Self {
255        Self::new(
256            EntryType::UserMessage,
257            serde_json::json!({ "text": content.into() }),
258        )
259    }
260
261    /// Create an assistant message entry.
262    pub fn assistant_message(content: impl Into<String>) -> Self {
263        Self::new(
264            EntryType::AssistantMessage,
265            serde_json::json!({ "text": content.into() }),
266        )
267    }
268
269    /// Create a tool call entry.
270    pub fn tool_call(tool_name: impl Into<String>, success: bool, result: Value) -> Self {
271        Self::new(
272            EntryType::ToolCall {
273                tool_name: tool_name.into(),
274                success,
275            },
276            result,
277        )
278    }
279
280    /// Create a skill execution entry.
281    pub fn skill_execution(skill_name: impl Into<String>, success: bool, result: Value) -> Self {
282        Self::new(
283            EntryType::SkillExecution {
284                skill_name: skill_name.into(),
285                success,
286            },
287            result,
288        )
289    }
290}
291
292/// Complete session data.
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct Session {
295    /// Session metadata.
296    pub metadata: SessionMetadata,
297    /// Session entries (conversation history, tool calls, etc.).
298    pub entries: Vec<SessionEntry>,
299    /// Session context data (key-value store).
300    #[serde(default)]
301    pub context: HashMap<String, Value>,
302}
303
304impl Session {
305    /// Create a new session.
306    pub fn new(name: impl Into<String>, session_type: SessionType) -> Self {
307        Self {
308            metadata: SessionMetadata::new(name, session_type),
309            entries: Vec::new(),
310            context: HashMap::new(),
311        }
312    }
313
314    /// Get the session ID.
315    pub fn id(&self) -> &SessionId {
316        &self.metadata.id
317    }
318
319    /// Get the session name.
320    pub fn name(&self) -> &str {
321        &self.metadata.name
322    }
323
324    /// Get the session status.
325    pub fn status(&self) -> SessionStatus {
326        self.metadata.status
327    }
328
329    /// Add an entry to the session.
330    pub fn add_entry(&mut self, entry: SessionEntry) {
331        self.entries.push(entry);
332        self.metadata.updated_at = Timestamp::now();
333    }
334
335    /// Add a user message.
336    pub fn add_user_message(&mut self, content: impl Into<String>) {
337        self.add_entry(SessionEntry::user_message(content));
338    }
339
340    /// Add an assistant message.
341    pub fn add_assistant_message(&mut self, content: impl Into<String>) {
342        self.add_entry(SessionEntry::assistant_message(content));
343    }
344
345    /// Set context value.
346    pub fn set_context(&mut self, key: impl Into<String>, value: Value) {
347        self.context.insert(key.into(), value);
348        self.metadata.updated_at = Timestamp::now();
349    }
350
351    /// Get context value.
352    pub fn get_context(&self, key: &str) -> Option<&Value> {
353        self.context.get(key)
354    }
355
356    /// Update session status.
357    pub fn set_status(&mut self, status: SessionStatus) {
358        self.metadata.status = status;
359        self.metadata.updated_at = Timestamp::now();
360    }
361
362    /// Mark session as completed.
363    pub fn complete(&mut self) {
364        self.set_status(SessionStatus::Completed);
365    }
366
367    /// Mark session as failed.
368    pub fn fail(&mut self) {
369        self.set_status(SessionStatus::Failed);
370    }
371
372    /// Count the number of turns in the session.
373    ///
374    /// A turn is defined as a user message followed by an assistant message.
375    /// This counts the number of complete turns.
376    pub fn turn_count(&self) -> u32 {
377        let mut turns = 0;
378        let mut awaiting_response = false;
379
380        for entry in &self.entries {
381            match &entry.entry_type {
382                EntryType::UserMessage => {
383                    awaiting_response = true;
384                }
385                EntryType::AssistantMessage => {
386                    if awaiting_response {
387                        turns += 1;
388                        awaiting_response = false;
389                    }
390                }
391                _ => {}
392            }
393        }
394
395        turns
396    }
397
398    /// Count user messages in the session.
399    pub fn user_message_count(&self) -> usize {
400        self.entries
401            .iter()
402            .filter(|e| matches!(e.entry_type, EntryType::UserMessage))
403            .count()
404    }
405
406    /// Count assistant messages in the session.
407    pub fn assistant_message_count(&self) -> usize {
408        self.entries
409            .iter()
410            .filter(|e| matches!(e.entry_type, EntryType::AssistantMessage))
411            .count()
412    }
413
414    /// Get the duration of the session.
415    pub fn duration(&self) -> Duration {
416        let start = self.metadata.created_at.as_millis();
417        let end = self.metadata.updated_at.as_millis();
418        Duration::from_millis(end.saturating_sub(start))
419    }
420}
421
422/// Configuration for session limits.
423#[derive(Debug, Clone, Serialize, Deserialize)]
424pub struct SessionConfig {
425    /// Maximum number of turns allowed (None = unlimited).
426    pub max_turns: Option<u32>,
427    /// Maximum number of entries allowed (None = unlimited).
428    pub max_entries: Option<usize>,
429    /// Maximum session duration (None = unlimited).
430    pub max_duration: Option<Duration>,
431    /// Action to take when limit is reached.
432    pub limit_action: LimitAction,
433}
434
435impl Default for SessionConfig {
436    fn default() -> Self {
437        Self {
438            max_turns: None,
439            max_entries: None,
440            max_duration: None,
441            limit_action: LimitAction::Error,
442        }
443    }
444}
445
446impl SessionConfig {
447    /// Create a new session config with default values.
448    pub fn new() -> Self {
449        Self::default()
450    }
451
452    /// Set maximum turns.
453    pub fn with_max_turns(mut self, max: u32) -> Self {
454        self.max_turns = Some(max);
455        self
456    }
457
458    /// Set maximum entries.
459    pub fn with_max_entries(mut self, max: usize) -> Self {
460        self.max_entries = Some(max);
461        self
462    }
463
464    /// Set maximum duration.
465    pub fn with_max_duration(mut self, max: Duration) -> Self {
466        self.max_duration = Some(max);
467        self
468    }
469
470    /// Set limit action.
471    pub fn with_limit_action(mut self, action: LimitAction) -> Self {
472        self.limit_action = action;
473        self
474    }
475}
476
477/// Action to take when a session limit is reached.
478#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
479#[serde(rename_all = "snake_case")]
480pub enum LimitAction {
481    /// Return an error.
482    #[default]
483    Error,
484    /// End the session gracefully.
485    EndSession,
486    /// Ignore and continue (for soft limits).
487    Ignore,
488}
489
/// Outcome of checking a session against its configured limits.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum LimitCheck {
    /// Every configured limit still has headroom.
    Ok,
    /// A limit has been reached exactly but not passed.
    AtLimit,
    /// A limit has been passed; the payload says which one.
    Exceeded(LimitExceeded),
}

/// Identifies the limit that was exceeded, with the observed and allowed values.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum LimitExceeded {
    /// Too many turns.
    Turns {
        /// Observed turn count.
        current: u32,
        /// Configured maximum.
        max: u32,
    },
    /// Too many entries.
    Entries {
        /// Observed entry count.
        current: usize,
        /// Configured maximum.
        max: usize,
    },
    /// Session ran too long.
    Duration {
        /// Observed session duration.
        current: Duration,
        /// Configured maximum.
        max: Duration,
    },
}
511
512impl Session {
513    /// Check if the session is within limits.
514    pub fn check_limits(&self, config: &SessionConfig) -> LimitCheck {
515        // Check turn limit
516        if let Some(max_turns) = config.max_turns {
517            let current = self.turn_count();
518            if current > max_turns {
519                return LimitCheck::Exceeded(LimitExceeded::Turns {
520                    current,
521                    max: max_turns,
522                });
523            }
524            if current == max_turns {
525                return LimitCheck::AtLimit;
526            }
527        }
528
529        // Check entry limit
530        if let Some(max_entries) = config.max_entries {
531            let current = self.entries.len();
532            if current > max_entries {
533                return LimitCheck::Exceeded(LimitExceeded::Entries {
534                    current,
535                    max: max_entries,
536                });
537            }
538            if current == max_entries {
539                return LimitCheck::AtLimit;
540            }
541        }
542
543        // Check duration limit
544        if let Some(max_duration) = config.max_duration {
545            let current = self.duration();
546            if current > max_duration {
547                return LimitCheck::Exceeded(LimitExceeded::Duration {
548                    current,
549                    max: max_duration,
550                });
551            }
552        }
553
554        LimitCheck::Ok
555    }
556
557    /// Check if the session is at its turn limit.
558    pub fn is_at_turn_limit(&self, config: &SessionConfig) -> bool {
559        if let Some(max_turns) = config.max_turns {
560            self.turn_count() >= max_turns
561        } else {
562            false
563        }
564    }
565
566    /// Get remaining turns before limit.
567    pub fn remaining_turns(&self, config: &SessionConfig) -> Option<u32> {
568        config
569            .max_turns
570            .map(|max| max.saturating_sub(self.turn_count()))
571    }
572}
573
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an empty conversation session with the given name and purpose.
    fn conversation(name: &str, purpose: &str) -> Session {
        Session::new(
            name,
            SessionType::Conversation {
                purpose: purpose.to_string(),
            },
        )
    }

    #[test]
    fn test_session_id() {
        let original = SessionId::new();
        let round_tripped = SessionId::from_string(&original.as_str()).unwrap();
        assert_eq!(original, round_tripped);
    }

    #[test]
    fn test_timestamp() {
        let earlier = Timestamp::now();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let later = Timestamp::now();
        assert!(later.as_millis() > earlier.as_millis());
    }

    #[test]
    fn test_session_creation() {
        let session = conversation("Test Session", "Testing");

        assert_eq!(session.name(), "Test Session");
        assert_eq!(session.status(), SessionStatus::Active);
        assert_eq!(session.entries.len(), 0);
    }

    #[test]
    fn test_session_entries() {
        let mut session = conversation("Test", "Test");

        let exchanges = [("Hello", "Hi there!"), ("How are you?", "I'm doing well!")];
        for &(user, assistant) in exchanges.iter() {
            session.add_user_message(user);
            session.add_assistant_message(assistant);
        }

        assert_eq!(session.entries.len(), 4);
        assert_eq!(session.user_message_count(), 2);
        assert_eq!(session.assistant_message_count(), 2);
    }

    #[test]
    fn test_turn_counting() {
        let mut session = conversation("Test", "Test");
        assert_eq!(session.turn_count(), 0);

        session.add_user_message("Hello");
        // A user message alone is not yet a complete turn.
        assert_eq!(session.turn_count(), 0);

        session.add_assistant_message("Hi!");
        assert_eq!(session.turn_count(), 1);

        session.add_user_message("How are you?");
        session.add_assistant_message("Good!");
        assert_eq!(session.turn_count(), 2);
    }

    #[test]
    fn test_session_limits() {
        let mut session = conversation("Test", "Test");
        let config = SessionConfig::new().with_max_turns(2);

        assert_eq!(session.check_limits(&config), LimitCheck::Ok);
        assert_eq!(session.remaining_turns(&config), Some(2));
        assert!(!session.is_at_turn_limit(&config));

        // Turn one.
        session.add_user_message("Hello");
        session.add_assistant_message("Hi!");
        assert_eq!(session.remaining_turns(&config), Some(1));

        // Turn two: exactly at the limit.
        session.add_user_message("How are you?");
        session.add_assistant_message("Good!");
        assert_eq!(session.check_limits(&config), LimitCheck::AtLimit);
        assert!(session.is_at_turn_limit(&config));
        assert_eq!(session.remaining_turns(&config), Some(0));

        // Turn three: past the limit.
        session.add_user_message("What's up?");
        session.add_assistant_message("Nothing much!");
        assert!(matches!(
            session.check_limits(&config),
            LimitCheck::Exceeded(LimitExceeded::Turns { current: 3, max: 2 })
        ));
    }

    #[test]
    fn test_session_context() {
        let mut session = conversation("Test", "Test");

        session.set_context("key1", serde_json::json!("value1"));
        session.set_context("key2", serde_json::json!(42));

        assert_eq!(
            session.get_context("key1"),
            Some(&serde_json::json!("value1"))
        );
        assert_eq!(session.get_context("key2"), Some(&serde_json::json!(42)));
        assert_eq!(session.get_context("key3"), None);
    }

    #[test]
    fn test_session_status() {
        let mut session = conversation("Test", "Test");
        assert_eq!(session.status(), SessionStatus::Active);

        session.complete();
        assert_eq!(session.status(), SessionStatus::Completed);

        session.fail();
        assert_eq!(session.status(), SessionStatus::Failed);
    }

    #[test]
    fn test_session_serialization() {
        let mut session = Session::new(
            "Test",
            SessionType::TeacherDemo {
                task: "Demo task".to_string(),
                model: "gpt-4".to_string(),
            },
        );
        session.add_user_message("Hello");
        session.add_assistant_message("Hi!");
        session.set_context("key", serde_json::json!("value"));

        let encoded = serde_json::to_string(&session).unwrap();
        let decoded: Session = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.name(), session.name());
        assert_eq!(decoded.entries.len(), 2);
        assert_eq!(
            decoded.get_context("key"),
            Some(&serde_json::json!("value"))
        );
    }

    #[test]
    fn test_session_entry_types() {
        let tool_entry =
            SessionEntry::tool_call("my_tool", true, serde_json::json!({"result": "ok"}));
        assert!(matches!(
            &tool_entry.entry_type,
            EntryType::ToolCall { tool_name, success } if tool_name == "my_tool" && *success
        ));

        let skill_entry = SessionEntry::skill_execution("my_skill", false, serde_json::json!({}));
        assert!(matches!(
            &skill_entry.entry_type,
            EntryType::SkillExecution { skill_name, success } if skill_name == "my_skill" && !*success
        ));
    }
}