1use std::sync::Arc;
7use std::sync::atomic::{AtomicU8, Ordering};
8
/// Lifecycle phase of a session.
///
/// The explicit discriminants are the raw bytes kept in `SessionState`'s
/// atomic cell and decoded back through `From<u8>`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionPhase {
    /// Starting phase; only `"initialize"` and `"ping"` requests are accepted here.
    Uninitialized = 0,
    /// Initialization has begun but has not yet been marked complete.
    Initializing = 1,
    /// Initialization has been marked complete.
    Initialized = 2,
}
20
21impl From<u8> for SessionPhase {
22 fn from(value: u8) -> Self {
23 match value {
24 0 => SessionPhase::Uninitialized,
25 1 => SessionPhase::Initializing,
26 2 => SessionPhase::Initialized,
27 _ => SessionPhase::Uninitialized,
28 }
29 }
30}
31
/// Shared holder of a session's [`SessionPhase`].
///
/// The phase is stored as its `u8` discriminant inside an `Arc<AtomicU8>`.
/// Cloning a `SessionState` therefore clones the `Arc`, not the value, and
/// every clone observes and updates the same phase.
///
/// `Debug` is derived so that this public type can be formatted with `{:?}`
/// and embedded in other `Debug` types. It prints the raw stored byte.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Raw discriminant of the current [`SessionPhase`], shared across clones.
    phase: Arc<AtomicU8>,
}
39
40impl Default for SessionState {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl SessionState {
47 pub fn new() -> Self {
49 Self {
50 phase: Arc::new(AtomicU8::new(SessionPhase::Uninitialized as u8)),
51 }
52 }
53
54 pub fn phase(&self) -> SessionPhase {
56 SessionPhase::from(self.phase.load(Ordering::Acquire))
57 }
58
59 pub fn is_initialized(&self) -> bool {
61 self.phase() == SessionPhase::Initialized
62 }
63
64 pub fn mark_initializing(&self) -> bool {
68 self.phase
69 .compare_exchange(
70 SessionPhase::Uninitialized as u8,
71 SessionPhase::Initializing as u8,
72 Ordering::AcqRel,
73 Ordering::Acquire,
74 )
75 .is_ok()
76 }
77
78 pub fn mark_initialized(&self) -> bool {
82 self.phase
83 .compare_exchange(
84 SessionPhase::Initializing as u8,
85 SessionPhase::Initialized as u8,
86 Ordering::AcqRel,
87 Ordering::Acquire,
88 )
89 .is_ok()
90 }
91
92 pub fn is_request_allowed(&self, method: &str) -> bool {
97 match self.phase() {
98 SessionPhase::Uninitialized => {
99 matches!(method, "initialize" | "ping")
100 }
101 SessionPhase::Initializing | SessionPhase::Initialized => true,
102 }
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks a single session through every phase and checks request gating
    /// and transition results at each step.
    #[test]
    fn test_session_lifecycle() {
        let session = SessionState::new();

        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert!(!session.is_initialized());

        assert!(session.is_request_allowed("initialize"));
        assert!(session.is_request_allowed("ping"));
        assert!(!session.is_request_allowed("tools/list"));

        assert!(session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initializing);
        assert!(!session.is_initialized());

        // A second start attempt must not succeed.
        assert!(!session.mark_initializing());

        assert!(session.is_request_allowed("tools/list"));

        assert!(session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
        assert!(session.is_initialized());

        assert!(!session.mark_initialized());
    }

    /// Clones share one phase cell, so a transition through either is visible
    /// through the other.
    #[test]
    fn test_session_clone_shares_state() {
        let session1 = SessionState::new();
        let session2 = session1.clone();

        assert!(session1.mark_initializing());
        assert_eq!(session2.phase(), SessionPhase::Initializing);

        assert!(session2.mark_initialized());
        assert_eq!(session1.phase(), SessionPhase::Initialized);
    }

    /// `mark_initialized` cannot skip the `Initializing` phase.
    #[test]
    fn test_mark_initialized_requires_initializing() {
        let session = SessionState::new();

        assert!(!session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert!(!session.is_initialized());
    }

    /// Once initialized, every method is allowed and the handshake cannot restart.
    #[test]
    fn test_initialized_session_allows_everything_and_cannot_restart() {
        let session = SessionState::new();
        assert!(session.mark_initializing());
        assert!(session.mark_initialized());

        for method in ["initialize", "ping", "tools/list", ""] {
            assert!(session.is_request_allowed(method), "method {method:?}");
        }

        assert!(!session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initialized);
    }

    /// Every variant round-trips through its `u8` discriminant, and unknown
    /// bytes fall back to `Uninitialized`.
    #[test]
    fn test_phase_from_u8_round_trips_and_falls_back() {
        for phase in [
            SessionPhase::Uninitialized,
            SessionPhase::Initializing,
            SessionPhase::Initialized,
        ] {
            assert_eq!(SessionPhase::from(phase as u8), phase);
        }
        assert_eq!(SessionPhase::from(3), SessionPhase::Uninitialized);
        assert_eq!(SessionPhase::from(u8::MAX), SessionPhase::Uninitialized);
    }

    /// The compare-and-swap guarantees a single winner when several clones race
    /// to start initialization.
    #[test]
    fn test_concurrent_mark_initializing_has_single_winner() {
        let session = SessionState::new();

        let handles: Vec<_> = (0..8)
            .map(|_| {
                let worker = session.clone();
                std::thread::spawn(move || worker.mark_initializing())
            })
            .collect();

        let winners = handles
            .into_iter()
            .map(|handle| handle.join().expect("worker thread panicked"))
            .filter(|&won| won)
            .count();

        assert_eq!(winners, 1);
        assert_eq!(session.phase(), SessionPhase::Initializing);
    }
}