Skip to main content

tower_mcp/
session.rs

1//! MCP session state management
2//!
3//! Tracks the lifecycle state of an MCP connection as per the specification.
4//! The session progresses through phases: Uninitialized -> Initializing -> Initialized.
5//!
6//! Sessions also support type-safe extensions for storing arbitrary data like
7//! authentication claims, user roles, or other session-scoped state.
8
9use std::sync::Arc;
10use std::sync::RwLock;
11use std::sync::atomic::{AtomicU8, Ordering};
12
13use crate::router::Extensions;
14
/// Lifecycle phase of an MCP session.
///
/// Each variant has a fixed `u8` discriminant so the phase can be kept in an
/// `AtomicU8` and decoded again with `From<u8>`.
#[non_exhaustive]
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionPhase {
    /// Fresh session: only `initialize` and `ping` requests are accepted.
    Uninitialized = 0,
    /// The server has answered `initialize` and is waiting for the client's
    /// `initialized` notification.
    Initializing = 1,
    /// The `initialized` notification has arrived; normal operation.
    Initialized = 2,
}
27
28impl From<u8> for SessionPhase {
29    fn from(value: u8) -> Self {
30        match value {
31            0 => SessionPhase::Uninitialized,
32            1 => SessionPhase::Initializing,
33            2 => SessionPhase::Initialized,
34            _ => SessionPhase::Uninitialized,
35        }
36    }
37}
38
39/// Shared session state that can be cloned across requests.
40///
41/// Uses atomic operations for thread-safe state transitions. Includes a type-safe
42/// extensions map for storing session-scoped data like authentication claims.
43///
44/// # Example
45///
46/// ```rust
47/// use tower_mcp::SessionState;
48///
49/// #[derive(Debug, Clone)]
50/// struct UserClaims {
51///     user_id: String,
52///     role: String,
53/// }
54///
55/// let session = SessionState::new();
56///
57/// // Store auth claims in the session
58/// session.insert(UserClaims {
59///     user_id: "user123".to_string(),
60///     role: "admin".to_string(),
61/// });
62///
63/// // Retrieve claims later
64/// if let Some(claims) = session.get::<UserClaims>() {
65///     assert_eq!(claims.role, "admin");
66/// }
67/// ```
68#[derive(Clone)]
69pub struct SessionState {
70    phase: Arc<AtomicU8>,
71    extensions: Arc<RwLock<Extensions>>,
72}
73
74impl Default for SessionState {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl SessionState {
81    /// Create a new session in the Uninitialized phase
82    pub fn new() -> Self {
83        Self {
84            phase: Arc::new(AtomicU8::new(SessionPhase::Uninitialized as u8)),
85            extensions: Arc::new(RwLock::new(Extensions::new())),
86        }
87    }
88
89    /// Insert a value into the session extensions.
90    ///
91    /// This is typically used by auth middleware to store claims that can
92    /// be checked by capability filters.
93    ///
94    /// # Example
95    ///
96    /// ```rust
97    /// use tower_mcp::SessionState;
98    ///
99    /// let session = SessionState::new();
100    /// session.insert(42u32);
101    /// assert_eq!(session.get::<u32>(), Some(42));
102    /// ```
103    pub fn insert<T: Send + Sync + Clone + 'static>(&self, val: T) {
104        if let Ok(mut ext) = self.extensions.write() {
105            ext.insert(val);
106        }
107    }
108
109    /// Get a cloned value from the session extensions.
110    ///
111    /// Returns `None` if no value of the given type has been inserted or if
112    /// the lock cannot be acquired.
113    ///
114    /// # Example
115    ///
116    /// ```rust
117    /// use tower_mcp::SessionState;
118    ///
119    /// let session = SessionState::new();
120    /// session.insert("hello".to_string());
121    /// assert_eq!(session.get::<String>(), Some("hello".to_string()));
122    /// assert_eq!(session.get::<u32>(), None);
123    /// ```
124    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
125        self.extensions
126            .read()
127            .ok()
128            .and_then(|ext| ext.get::<T>().cloned())
129    }
130
131    /// Get the current session phase
132    pub fn phase(&self) -> SessionPhase {
133        SessionPhase::from(self.phase.load(Ordering::Acquire))
134    }
135
136    /// Check if the session is initialized (operation phase)
137    pub fn is_initialized(&self) -> bool {
138        self.phase() == SessionPhase::Initialized
139    }
140
141    /// Transition from Uninitialized to Initializing.
142    /// Called after responding to an `initialize` request.
143    /// Returns true if the transition was successful.
144    pub fn mark_initializing(&self) -> bool {
145        self.phase
146            .compare_exchange(
147                SessionPhase::Uninitialized as u8,
148                SessionPhase::Initializing as u8,
149                Ordering::AcqRel,
150                Ordering::Acquire,
151            )
152            .is_ok()
153    }
154
155    /// Transition to Initialized phase.
156    /// Called when receiving an `initialized` notification.
157    ///
158    /// Accepts transitions from both `Initializing` and `Uninitialized` states.
159    /// The `Uninitialized → Initialized` path handles a race condition in HTTP
160    /// transports where the client sends the `initialized` notification before
161    /// the server has finished processing the `initialize` request (the session
162    /// is stored in `Uninitialized` state before the request is dispatched).
163    ///
164    /// Returns true if the transition was successful.
165    pub fn mark_initialized(&self) -> bool {
166        // Try the expected path first: Initializing → Initialized
167        if self
168            .phase
169            .compare_exchange(
170                SessionPhase::Initializing as u8,
171                SessionPhase::Initialized as u8,
172                Ordering::AcqRel,
173                Ordering::Acquire,
174            )
175            .is_ok()
176        {
177            return true;
178        }
179
180        // Handle the race: Uninitialized → Initialized
181        // This occurs when the initialized notification arrives before
182        // the initialize request has been fully processed.
183        self.phase
184            .compare_exchange(
185                SessionPhase::Uninitialized as u8,
186                SessionPhase::Initialized as u8,
187                Ordering::AcqRel,
188                Ordering::Acquire,
189            )
190            .is_ok()
191    }
192
193    /// Check if a request method is allowed in the current phase.
194    /// Per spec:
195    /// - Before initialization: only `initialize` and `ping` are valid
196    /// - During all phases: `ping` is always valid
197    pub fn is_request_allowed(&self, method: &str) -> bool {
198        match self.phase() {
199            SessionPhase::Uninitialized => {
200                matches!(method, "initialize" | "ping")
201            }
202            SessionPhase::Initializing | SessionPhase::Initialized => true,
203        }
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_lifecycle() {
        let session = SessionState::new();

        // Initial state
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert!(!session.is_initialized());

        // Only initialize and ping allowed
        assert!(session.is_request_allowed("initialize"));
        assert!(session.is_request_allowed("ping"));
        assert!(!session.is_request_allowed("tools/list"));

        // Transition to initializing
        assert!(session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initializing);
        assert!(!session.is_initialized());

        // Can't mark initializing again
        assert!(!session.mark_initializing());

        // All requests allowed during initializing
        assert!(session.is_request_allowed("tools/list"));

        // Transition to initialized
        assert!(session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
        assert!(session.is_initialized());

        // Neither transition can be repeated once initialized
        assert!(!session.mark_initialized());
        assert!(!session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_default_is_uninitialized() {
        let session = SessionState::default();
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert_eq!(session.get::<u32>(), None);
    }

    #[test]
    fn test_session_clone_shares_state() {
        let session1 = SessionState::new();
        let session2 = session1.clone();

        assert!(session1.mark_initializing());
        assert_eq!(session2.phase(), SessionPhase::Initializing);

        assert!(session2.mark_initialized());
        assert_eq!(session1.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_session_extensions_insert_and_get() {
        let session = SessionState::new();

        // Nothing stored yet
        assert_eq!(session.get::<u32>(), None);

        // Insert and retrieve a value
        session.insert(42u32);
        assert_eq!(session.get::<u32>(), Some(42));

        // Different type returns None
        assert_eq!(session.get::<String>(), None);
    }

    #[test]
    fn test_session_extensions_overwrite() {
        let session = SessionState::new();

        session.insert(42u32);
        assert_eq!(session.get::<u32>(), Some(42));

        // Overwrite with new value
        session.insert(100u32);
        assert_eq!(session.get::<u32>(), Some(100));
    }

    #[test]
    fn test_session_extensions_multiple_types() {
        let session = SessionState::new();

        session.insert(42u32);
        session.insert("hello".to_string());
        session.insert(true);

        assert_eq!(session.get::<u32>(), Some(42));
        assert_eq!(session.get::<String>(), Some("hello".to_string()));
        assert_eq!(session.get::<bool>(), Some(true));
    }

    #[test]
    fn test_session_extensions_shared_across_clones() {
        let session1 = SessionState::new();
        let session2 = session1.clone();

        // Insert in one clone
        session1.insert(42u32);

        // Should be visible in the other
        assert_eq!(session2.get::<u32>(), Some(42));

        // Insert in the second clone
        session2.insert("world".to_string());

        // Should be visible in the first
        assert_eq!(session1.get::<String>(), Some("world".to_string()));
    }

    #[test]
    fn test_mark_initialized_from_uninitialized() {
        let session = SessionState::new();

        // Start in Uninitialized, skip straight to Initialized
        // This handles the race where `initialized` notification arrives
        // before the `initialize` request is fully processed.
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert!(session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
        assert!(session.is_initialized());

        // A late `initialize` response must not move the session backwards
        assert!(!session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initialized);

        // All requests allowed
        assert!(session.is_request_allowed("tools/list"));
        assert!(session.is_request_allowed("ping"));
    }

    #[test]
    fn test_mark_initialized_rejected_when_already_initialized() {
        let session = SessionState::new();

        // Normal lifecycle
        assert!(session.mark_initializing());
        assert!(session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);

        // Calling mark_initialized again reports failure but leaves the state intact
        assert!(!session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_ping_allowed_in_every_phase() {
        let session = SessionState::new();
        assert!(session.is_request_allowed("ping"));

        assert!(session.mark_initializing());
        assert!(session.is_request_allowed("ping"));

        assert!(session.mark_initialized());
        assert!(session.is_request_allowed("ping"));
    }

    #[test]
    fn test_phase_from_u8() {
        assert_eq!(SessionPhase::from(0), SessionPhase::Uninitialized);
        assert_eq!(SessionPhase::from(1), SessionPhase::Initializing);
        assert_eq!(SessionPhase::from(2), SessionPhase::Initialized);

        // Unknown discriminants fall back to the initial phase
        assert_eq!(SessionPhase::from(u8::MAX), SessionPhase::Uninitialized);
    }

    #[test]
    fn test_session_extensions_custom_type() {
        #[derive(Debug, Clone, PartialEq)]
        struct UserClaims {
            user_id: String,
            role: String,
        }

        let session = SessionState::new();

        session.insert(UserClaims {
            user_id: "user123".to_string(),
            role: "admin".to_string(),
        });

        assert_eq!(
            session.get::<UserClaims>(),
            Some(UserClaims {
                user_id: "user123".to_string(),
                role: "admin".to_string(),
            })
        );
    }
}