1use std::sync::Arc;
10use std::sync::RwLock;
11use std::sync::atomic::{AtomicU8, Ordering};
12
13use crate::router::Extensions;
14
/// Lifecycle phase of a session.
///
/// The explicit `u8` discriminants are the raw values stored in the atomic
/// phase cell of `SessionState`, so they must stay stable and must keep
/// fitting in a single byte.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionPhase {
    /// No initialization has started yet (raw value `0`).
    Uninitialized = 0,
    /// Initialization has started but has not completed (raw value `1`).
    Initializing = 1,
    /// Initialization has completed (raw value `2`).
    Initialized = 2,
}
26
27impl From<u8> for SessionPhase {
28 fn from(value: u8) -> Self {
29 match value {
30 0 => SessionPhase::Uninitialized,
31 1 => SessionPhase::Initializing,
32 2 => SessionPhase::Initialized,
33 _ => SessionPhase::Uninitialized,
34 }
35 }
36}
37
38#[derive(Clone)]
68pub struct SessionState {
69 phase: Arc<AtomicU8>,
70 extensions: Arc<RwLock<Extensions>>,
71}
72
73impl Default for SessionState {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl SessionState {
80 pub fn new() -> Self {
82 Self {
83 phase: Arc::new(AtomicU8::new(SessionPhase::Uninitialized as u8)),
84 extensions: Arc::new(RwLock::new(Extensions::new())),
85 }
86 }
87
88 pub fn insert<T: Send + Sync + Clone + 'static>(&self, val: T) {
103 if let Ok(mut ext) = self.extensions.write() {
104 ext.insert(val);
105 }
106 }
107
108 pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
124 self.extensions
125 .read()
126 .ok()
127 .and_then(|ext| ext.get::<T>().cloned())
128 }
129
130 pub fn phase(&self) -> SessionPhase {
132 SessionPhase::from(self.phase.load(Ordering::Acquire))
133 }
134
135 pub fn is_initialized(&self) -> bool {
137 self.phase() == SessionPhase::Initialized
138 }
139
140 pub fn mark_initializing(&self) -> bool {
144 self.phase
145 .compare_exchange(
146 SessionPhase::Uninitialized as u8,
147 SessionPhase::Initializing as u8,
148 Ordering::AcqRel,
149 Ordering::Acquire,
150 )
151 .is_ok()
152 }
153
154 pub fn mark_initialized(&self) -> bool {
158 self.phase
159 .compare_exchange(
160 SessionPhase::Initializing as u8,
161 SessionPhase::Initialized as u8,
162 Ordering::AcqRel,
163 Ordering::Acquire,
164 )
165 .is_ok()
166 }
167
168 pub fn is_request_allowed(&self, method: &str) -> bool {
173 match self.phase() {
174 SessionPhase::Uninitialized => {
175 matches!(method, "initialize" | "ping")
176 }
177 SessionPhase::Initializing | SessionPhase::Initialized => true,
178 }
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a session together with a clone of it, for the sharing tests.
    fn shared_pair() -> (SessionState, SessionState) {
        let first = SessionState::new();
        let second = first.clone();
        (first, second)
    }

    /// Walks one session through every phase, checking the request gate and
    /// the one-shot nature of each transition along the way.
    #[test]
    fn test_session_lifecycle() {
        let state = SessionState::new();

        // Fresh session: only the handshake methods get through.
        assert_eq!(state.phase(), SessionPhase::Uninitialized);
        assert!(!state.is_initialized());
        assert!(state.is_request_allowed("initialize"));
        assert!(state.is_request_allowed("ping"));
        assert!(!state.is_request_allowed("tools/list"));

        // First transition succeeds exactly once and opens the gate.
        assert!(state.mark_initializing());
        assert_eq!(state.phase(), SessionPhase::Initializing);
        assert!(!state.is_initialized());
        assert!(!state.mark_initializing());
        assert!(state.is_request_allowed("tools/list"));

        // Second transition also succeeds exactly once.
        assert!(state.mark_initialized());
        assert_eq!(state.phase(), SessionPhase::Initialized);
        assert!(state.is_initialized());
        assert!(!state.mark_initialized());
    }

    /// A phase change made through one handle is visible through its clone.
    #[test]
    fn test_session_clone_shares_state() {
        let (first, second) = shared_pair();

        first.mark_initializing();
        assert_eq!(second.phase(), SessionPhase::Initializing);

        second.mark_initialized();
        assert_eq!(first.phase(), SessionPhase::Initialized);
    }

    /// A stored value can be read back; an absent type yields `None`.
    #[test]
    fn test_session_extensions_insert_and_get() {
        let state = SessionState::new();
        state.insert(42u32);

        assert_eq!(state.get::<u32>(), Some(42));
        assert_eq!(state.get::<String>(), None);
    }

    /// Inserting a second value of the same type replaces the first.
    #[test]
    fn test_session_extensions_overwrite() {
        let state = SessionState::new();

        state.insert(42u32);
        assert_eq!(state.get::<u32>(), Some(42));

        state.insert(100u32);
        assert_eq!(state.get::<u32>(), Some(100));
    }

    /// Values of different types are stored side by side.
    #[test]
    fn test_session_extensions_multiple_types() {
        let state = SessionState::new();

        state.insert(42u32);
        state.insert("hello".to_string());
        state.insert(true);

        assert_eq!(state.get::<u32>(), Some(42));
        assert_eq!(state.get::<String>(), Some("hello".to_string()));
        assert_eq!(state.get::<bool>(), Some(true));
    }

    /// Extension values inserted through one handle are readable through the
    /// other, in both directions.
    #[test]
    fn test_session_extensions_shared_across_clones() {
        let (first, second) = shared_pair();

        first.insert(42u32);
        assert_eq!(second.get::<u32>(), Some(42));

        second.insert("world".to_string());
        assert_eq!(first.get::<String>(), Some("world".to_string()));
    }

    /// A user-defined type round-trips through the store unchanged.
    #[test]
    fn test_session_extensions_custom_type() {
        #[derive(Debug, Clone, PartialEq)]
        struct UserClaims {
            user_id: String,
            role: String,
        }

        let expected = UserClaims {
            user_id: "user123".to_string(),
            role: "admin".to_string(),
        };

        let state = SessionState::new();
        state.insert(expected.clone());

        assert_eq!(state.get::<UserClaims>(), Some(expected));
    }
}