Skip to main content

tower_mcp/
session.rs

//! MCP session state management
//!
//! Tracks the lifecycle state of an MCP connection as defined by the MCP specification.
//! The session progresses through phases: Uninitialized -> Initializing -> Initialized.
5
6use std::sync::Arc;
7use std::sync::atomic::{AtomicU8, Ordering};
8
/// Session lifecycle phase.
///
/// The discriminants are pinned with `#[repr(u8)]` because the phase is kept
/// as a raw `u8` inside an atomic and decoded back via `From<u8>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum SessionPhase {
    /// Initial state - only `initialize` and `ping` requests are valid
    Uninitialized = 0,
    /// Server has responded to `initialize`, waiting for `initialized` notification
    Initializing = 1,
    /// `initialized` notification received, normal operation
    Initialized = 2,
}
20
21impl From<u8> for SessionPhase {
22    fn from(value: u8) -> Self {
23        match value {
24            0 => SessionPhase::Uninitialized,
25            1 => SessionPhase::Initializing,
26            2 => SessionPhase::Initialized,
27            _ => SessionPhase::Uninitialized,
28        }
29    }
30}
31
/// Shared session state that can be cloned across requests.
///
/// The phase lives in an `Arc<AtomicU8>`, so every clone refers to the same
/// underlying value. Transitions use atomic operations for thread safety.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Current [`SessionPhase`], stored as its `u8` discriminant.
    phase: Arc<AtomicU8>,
}
39
40impl Default for SessionState {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl SessionState {
47    /// Create a new session in the Uninitialized phase
48    pub fn new() -> Self {
49        Self {
50            phase: Arc::new(AtomicU8::new(SessionPhase::Uninitialized as u8)),
51        }
52    }
53
54    /// Get the current session phase
55    pub fn phase(&self) -> SessionPhase {
56        SessionPhase::from(self.phase.load(Ordering::Acquire))
57    }
58
59    /// Check if the session is initialized (operation phase)
60    pub fn is_initialized(&self) -> bool {
61        self.phase() == SessionPhase::Initialized
62    }
63
64    /// Transition from Uninitialized to Initializing.
65    /// Called after responding to an `initialize` request.
66    /// Returns true if the transition was successful.
67    pub fn mark_initializing(&self) -> bool {
68        self.phase
69            .compare_exchange(
70                SessionPhase::Uninitialized as u8,
71                SessionPhase::Initializing as u8,
72                Ordering::AcqRel,
73                Ordering::Acquire,
74            )
75            .is_ok()
76    }
77
78    /// Transition from Initializing to Initialized.
79    /// Called when receiving an `initialized` notification.
80    /// Returns true if the transition was successful.
81    pub fn mark_initialized(&self) -> bool {
82        self.phase
83            .compare_exchange(
84                SessionPhase::Initializing as u8,
85                SessionPhase::Initialized as u8,
86                Ordering::AcqRel,
87                Ordering::Acquire,
88            )
89            .is_ok()
90    }
91
92    /// Check if a request method is allowed in the current phase.
93    /// Per spec:
94    /// - Before initialization: only `initialize` and `ping` are valid
95    /// - During all phases: `ping` is always valid
96    pub fn is_request_allowed(&self, method: &str) -> bool {
97        match self.phase() {
98            SessionPhase::Uninitialized => {
99                matches!(method, "initialize" | "ping")
100            }
101            SessionPhase::Initializing | SessionPhase::Initialized => true,
102        }
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_lifecycle() {
        let session = SessionState::new();

        // Initial state
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert!(!session.is_initialized());

        // Only initialize and ping allowed
        assert!(session.is_request_allowed("initialize"));
        assert!(session.is_request_allowed("ping"));
        assert!(!session.is_request_allowed("tools/list"));

        // Transition to initializing
        assert!(session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initializing);
        assert!(!session.is_initialized());

        // Can't mark initializing again
        assert!(!session.mark_initializing());

        // All requests allowed during initializing
        assert!(session.is_request_allowed("tools/list"));

        // Transition to initialized
        assert!(session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
        assert!(session.is_initialized());

        // Can't mark initialized again
        assert!(!session.mark_initialized());
    }

    #[test]
    fn test_session_clone_shares_state() {
        let session1 = SessionState::new();
        let session2 = session1.clone();

        // Assert the transitions themselves succeed, not just the observed phase.
        assert!(session1.mark_initializing());
        assert_eq!(session2.phase(), SessionPhase::Initializing);

        assert!(session2.mark_initialized());
        assert_eq!(session1.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_cannot_skip_initializing() {
        let session = SessionState::new();

        // `initialized` before `initialize` must be rejected and leave state untouched.
        assert!(!session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
    }

    #[test]
    fn test_no_backward_transition() {
        let session = SessionState::new();
        assert!(session.mark_initializing());
        assert!(session.mark_initialized());

        // Once initialized, the session never moves back to `Initializing`.
        assert!(!session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_request_filtering_per_phase() {
        let session = SessionState::new();

        // Method matching before initialization is exact (case- and whitespace-sensitive).
        for &method in &["tools/list", "resources/read", "", "Initialize", "ping "] {
            assert!(!session.is_request_allowed(method), "{method:?} should be rejected");
        }

        assert!(session.mark_initializing());
        assert!(session.mark_initialized());

        for &method in &["initialize", "ping", "tools/list", "resources/read", ""] {
            assert!(session.is_request_allowed(method), "{method:?} should be allowed");
        }
    }

    #[test]
    fn test_phase_from_u8() {
        // Every phase round-trips through its discriminant.
        for &phase in &[
            SessionPhase::Uninitialized,
            SessionPhase::Initializing,
            SessionPhase::Initialized,
        ] {
            assert_eq!(SessionPhase::from(phase as u8), phase);
        }

        // Unknown bytes fall back to `Uninitialized`.
        assert_eq!(SessionPhase::from(3), SessionPhase::Uninitialized);
        assert_eq!(SessionPhase::from(u8::MAX), SessionPhase::Uninitialized);
    }

    #[test]
    fn test_default_matches_new() {
        assert_eq!(SessionState::default().phase(), SessionPhase::Uninitialized);
    }

    #[test]
    fn test_concurrent_mark_initializing_has_single_winner() {
        let session = SessionState::new();

        let handles: Vec<_> = (0..8)
            .map(|_| {
                let shared = session.clone();
                std::thread::spawn(move || shared.mark_initializing())
            })
            .collect();

        let winners = handles
            .into_iter()
            .map(|handle| handle.join().expect("worker thread panicked"))
            .filter(|&won| won)
            .count();

        // The compare-exchange must let exactly one racer perform the transition.
        assert_eq!(winners, 1);
        assert_eq!(session.phase(), SessionPhase::Initializing);
    }
}