turul_http_mcp_server/
mcp_session.rs

1//! Streamable HTTP session management for MCP transport
2//!
3//! This module implements a lightweight session management system
4//! for MCP streamable HTTP with SSE support.
5
6use std::collections::HashMap;
7use std::time::{Duration, Instant};
8use tokio::sync::{Mutex, broadcast};
9use tracing::{info, warn};
10use uuid::Uuid;
11
12use crate::protocol::McpProtocolVersion;
13
14/// Per-session state
15pub struct Session {
16    /// Broadcast sender for SSE notifications
17    pub sender: broadcast::Sender<String>,
18    /// When the session was created (touched on access)
19    pub created: Instant,
20    /// MCP protocol version for this session
21    pub version: McpProtocolVersion,
22}
23
24impl Session {
25    /// Update the "last used" timestamp to now (touch mechanism)
26    pub fn touch(&mut self) {
27        self.created = Instant::now();
28    }
29}
30
31/// A handle to a session's message stream
32pub struct SessionHandle {
33    pub session_id: String,
34    pub receiver: broadcast::Receiver<String>,
35}
36
37pub type SessionMap = Mutex<HashMap<String, Session>>;
38
39lazy_static::lazy_static! {
40    /// Internal map: session_id -> Session (sender, creation time, MCP Version)
41    static ref SESSIONS: SessionMap = Mutex::new(HashMap::new());
42}
43
44/// Create a brand new session.
45/// Returns the `session_id` and a `Receiver<String>` you can use to drive an SSE stream.
46pub async fn new_session(mcp_version: McpProtocolVersion) -> SessionHandle {
47    let session_id = Uuid::now_v7().to_string();
48    // 128-slot broadcast channel for JSON-RPC notifications
49    let (sender, receiver) = broadcast::channel(128);
50    let session = Session {
51        sender: sender.clone(),
52        created: Instant::now(),
53        version: mcp_version,
54    };
55    SESSIONS.lock().await.insert(session_id.clone(), session);
56    SessionHandle {
57        session_id,
58        receiver,
59    }
60}
61
62/// Fetch and "touch" the sender for this session, extending its lifetime.
63/// Returns `None` if the session does not exist or has already expired.
64pub async fn get_sender(session_id: &str) -> Option<broadcast::Sender<String>> {
65    let mut sessions = SESSIONS.lock().await;
66    if let Some(session) = sessions.get_mut(session_id) {
67        // Bump the creation time to now
68        session.touch();
69        // Return a clone of the sender
70        return Some(session.sender.clone());
71    }
72    None
73}
74
75/// Get a session's receiver for SSE streaming
76pub async fn get_receiver(session_id: &str) -> Option<broadcast::Receiver<String>> {
77    let mut sessions = SESSIONS.lock().await;
78    if let Some(session) = sessions.get_mut(session_id) {
79        session.touch();
80        return Some(session.sender.subscribe());
81    }
82    None
83}
84
85/// Check if a session exists and touch it
86pub async fn session_exists(session_id: &str) -> bool {
87    let mut sessions = SESSIONS.lock().await;
88    if let Some(session) = sessions.get_mut(session_id) {
89        session.touch();
90        true
91    } else {
92        false
93    }
94}
95
96/// Explicitly remove/terminate a session.
97/// You can call this on client disconnect or after HTTP GET SSE finishes.
98pub async fn remove_session(session_id: &str) -> bool {
99    SESSIONS.lock().await.remove(session_id).is_some()
100}
101
102/// Expire any sessions older than `max_age`, dropping them from the map.
103pub async fn expire_old(max_age: Duration) {
104    let cutoff = Instant::now() - max_age;
105    let mut sessions = SESSIONS.lock().await;
106    sessions.retain(|sid, session| {
107        let alive = session.created >= cutoff;
108        if !alive {
109            info!("Session {} expired", sid);
110        }
111        alive
112    });
113}
114
115/// Send the given JSON-RPC message to a specific session
116pub async fn send_to_session(session_id: &str, message: String) -> bool {
117    if let Some(sender) = get_sender(session_id).await {
118        sender.send(message).is_ok()
119    } else {
120        false
121    }
122}
123
124/// Send the given JSON-RPC message to every active session.
125pub async fn broadcast_to_all(message: String) {
126    let sessions = SESSIONS.lock().await;
127    for (sid, session) in sessions.iter() {
128        warn!("Sending message: {} to session {}", message, sid);
129        // Ignore errors (no subscribers, lag, etc.)
130        let _ = session.sender.send(message.clone());
131    }
132}
133
134/// Disconnect all sessions
135pub async fn disconnect_all() {
136    let mut sessions = SESSIONS.lock().await;
137    // Just clear the map: dropping each Session.sender
138    sessions.clear();
139    info!("All sessions have been disconnected");
140}
141
142/// Get count of active sessions
143pub async fn session_count() -> usize {
144    let sessions = SESSIONS.lock().await;
145    sessions.len()
146}
147
148/// Spawn session cleanup task for automatic session management
149pub fn spawn_session_cleanup() {
150    tokio::spawn(async {
151        let mut interval = tokio::time::interval(Duration::from_secs(60));
152        loop {
153            interval.tick().await;
154            expire_old(Duration::from_secs(30 * 60)).await; // 30 minutes
155        }
156    });
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_session_lifecycle() {
        let handle = new_session(McpProtocolVersion::V2025_06_18).await;
        let session_id = handle.session_id.clone();

        // A freshly created session is visible to every lookup helper.
        assert!(session_exists(&session_id).await);
        assert!(get_sender(&session_id).await.is_some());
        assert!(get_receiver(&session_id).await.is_some());

        // Removal succeeds exactly once.
        assert!(remove_session(&session_id).await);
        assert!(!remove_session(&session_id).await);

        // After removal every lookup reports the session as gone.
        assert!(!session_exists(&session_id).await);
        assert!(get_sender(&session_id).await.is_none());
        assert!(get_receiver(&session_id).await.is_none());
        assert!(!send_to_session(&session_id, "{}".to_string()).await);
    }

    #[tokio::test]
    async fn test_session_messaging() {
        let SessionHandle {
            session_id,
            mut receiver,
        } = new_session(McpProtocolVersion::V2025_06_18).await;

        // Send a message to the session.
        let message = r#"{"method":"test","params":{}}"#.to_string();
        assert!(send_to_session(&session_id, message.clone()).await);

        // It must arrive on the handle's receiver promptly.
        let received = tokio::time::timeout(Duration::from_millis(100), receiver.recv())
            .await
            .expect("message should arrive before the timeout")
            .expect("broadcast channel should still be open");
        assert_eq!(received, message);

        // The session map is process-global and shared by all tests; don't leak
        // this session into it.
        assert!(remove_session(&session_id).await);
    }
}