1use std::sync::Arc;
10use std::sync::RwLock;
11use std::sync::atomic::{AtomicU8, Ordering};
12
13use crate::router::Extensions;
14
/// Lifecycle phase of a session.
///
/// `SessionState` stores the phase as its `u8` discriminant inside an `AtomicU8`,
/// which is why the discriminants are explicit and the enum is `#[repr(u8)]`.
/// The `From<u8>` impl converts a stored byte back into a phase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum SessionPhase {
    /// Initial phase. Only `"initialize"` and `"ping"` requests are allowed.
    Uninitialized = 0,
    /// Entered via `SessionState::mark_initializing`; all requests are allowed.
    Initializing = 1,
    /// Entered via `SessionState::mark_initialized`; all requests are allowed.
    Initialized = 2,
}
27
28impl From<u8> for SessionPhase {
29 fn from(value: u8) -> Self {
30 match value {
31 0 => SessionPhase::Uninitialized,
32 1 => SessionPhase::Initializing,
33 2 => SessionPhase::Initialized,
34 _ => SessionPhase::Uninitialized,
35 }
36 }
37}
38
39#[derive(Clone)]
69pub struct SessionState {
70 phase: Arc<AtomicU8>,
71 extensions: Arc<RwLock<Extensions>>,
72}
73
74impl Default for SessionState {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl SessionState {
81 pub fn new() -> Self {
83 Self {
84 phase: Arc::new(AtomicU8::new(SessionPhase::Uninitialized as u8)),
85 extensions: Arc::new(RwLock::new(Extensions::new())),
86 }
87 }
88
89 pub fn insert<T: Send + Sync + Clone + 'static>(&self, val: T) {
104 if let Ok(mut ext) = self.extensions.write() {
105 ext.insert(val);
106 }
107 }
108
109 pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
125 self.extensions
126 .read()
127 .ok()
128 .and_then(|ext| ext.get::<T>().cloned())
129 }
130
131 pub fn phase(&self) -> SessionPhase {
133 SessionPhase::from(self.phase.load(Ordering::Acquire))
134 }
135
136 pub fn is_initialized(&self) -> bool {
138 self.phase() == SessionPhase::Initialized
139 }
140
141 pub fn mark_initializing(&self) -> bool {
145 self.phase
146 .compare_exchange(
147 SessionPhase::Uninitialized as u8,
148 SessionPhase::Initializing as u8,
149 Ordering::AcqRel,
150 Ordering::Acquire,
151 )
152 .is_ok()
153 }
154
155 pub fn mark_initialized(&self) -> bool {
166 if self
168 .phase
169 .compare_exchange(
170 SessionPhase::Initializing as u8,
171 SessionPhase::Initialized as u8,
172 Ordering::AcqRel,
173 Ordering::Acquire,
174 )
175 .is_ok()
176 {
177 return true;
178 }
179
180 self.phase
184 .compare_exchange(
185 SessionPhase::Uninitialized as u8,
186 SessionPhase::Initialized as u8,
187 Ordering::AcqRel,
188 Ordering::Acquire,
189 )
190 .is_ok()
191 }
192
193 pub fn is_request_allowed(&self, method: &str) -> bool {
198 match self.phase() {
199 SessionPhase::Uninitialized => {
200 matches!(method, "initialize" | "ping")
201 }
202 SessionPhase::Initializing | SessionPhase::Initialized => true,
203 }
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_lifecycle() {
        let session = SessionState::new();

        // A fresh session only accepts `initialize` and `ping`.
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert!(!session.is_initialized());
        assert!(session.is_request_allowed("initialize"));
        assert!(session.is_request_allowed("ping"));
        assert!(!session.is_request_allowed("tools/list"));

        // Only the first caller wins Uninitialized -> Initializing.
        assert!(session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initializing);
        assert!(!session.is_initialized());
        assert!(!session.mark_initializing());
        assert!(session.is_request_allowed("tools/list"));

        // Initializing -> Initialized succeeds exactly once.
        assert!(session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
        assert!(session.is_initialized());
        assert!(!session.mark_initialized());

        // An initialized session never falls back to Initializing.
        assert!(!session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_session_default_matches_new() {
        let session = SessionState::default();
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert_eq!(session.get::<u32>(), None);
    }

    #[test]
    fn test_phase_from_u8_round_trip_and_fallback() {
        for phase in [
            SessionPhase::Uninitialized,
            SessionPhase::Initializing,
            SessionPhase::Initialized,
        ]
        .iter()
        .copied()
        {
            assert_eq!(SessionPhase::from(phase as u8), phase);
        }
        // Unknown bytes decode to the initial phase.
        assert_eq!(SessionPhase::from(u8::MAX), SessionPhase::Uninitialized);
    }

    #[test]
    fn test_session_clone_shares_state() {
        let session1 = SessionState::new();
        let session2 = session1.clone();

        assert!(session1.mark_initializing());
        assert_eq!(session2.phase(), SessionPhase::Initializing);

        assert!(session2.mark_initialized());
        assert_eq!(session1.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_session_extensions_insert_and_get() {
        let session = SessionState::new();

        session.insert(42u32);
        assert_eq!(session.get::<u32>(), Some(42));

        // A type that was never inserted is absent.
        assert_eq!(session.get::<String>(), None);
    }

    #[test]
    fn test_session_extensions_overwrite() {
        let session = SessionState::new();

        session.insert(42u32);
        assert_eq!(session.get::<u32>(), Some(42));

        // A second insert of the same type replaces the first.
        session.insert(100u32);
        assert_eq!(session.get::<u32>(), Some(100));
    }

    #[test]
    fn test_session_extensions_multiple_types() {
        let session = SessionState::new();

        session.insert(42u32);
        session.insert("hello".to_string());
        session.insert(true);

        assert_eq!(session.get::<u32>(), Some(42));
        assert_eq!(session.get::<String>(), Some("hello".to_string()));
        assert_eq!(session.get::<bool>(), Some(true));
    }

    #[test]
    fn test_session_extensions_shared_across_clones() {
        let session1 = SessionState::new();
        let session2 = session1.clone();

        session1.insert(42u32);
        assert_eq!(session2.get::<u32>(), Some(42));

        session2.insert("world".to_string());
        assert_eq!(session1.get::<String>(), Some("world".to_string()));
    }

    #[test]
    fn test_mark_initialized_from_uninitialized() {
        let session = SessionState::new();
        assert_eq!(session.phase(), SessionPhase::Uninitialized);

        // Skipping the Initializing phase is permitted.
        assert!(session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
        assert!(session.is_initialized());

        assert!(session.is_request_allowed("tools/list"));
        assert!(session.is_request_allowed("ping"));
    }

    #[test]
    fn test_mark_initialized_idempotent_when_already_initialized() {
        let session = SessionState::new();

        assert!(session.mark_initializing());
        assert!(session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);

        // Repeating the call reports "no transition" and leaves the phase unchanged.
        assert!(!session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_session_extensions_custom_type() {
        #[derive(Debug, Clone, PartialEq)]
        struct UserClaims {
            user_id: String,
            role: String,
        }

        let session = SessionState::new();

        session.insert(UserClaims {
            user_id: "user123".to_string(),
            role: "admin".to_string(),
        });

        assert_eq!(
            session.get::<UserClaims>(),
            Some(UserClaims {
                user_id: "user123".to_string(),
                role: "admin".to_string(),
            })
        );
    }
}