turul_http_mcp_server/stream_manager.rs

//! Enhanced Stream Manager with MCP 2025-06-18 Resumability
//!
//! This module provides proper SSE stream management with:
//! - Event IDs for resumability
//! - Last-Event-ID header support
//! - Per-session event targeting (not broadcast to all)
//! - Event persistence and replay
//! - Proper HTTP status codes and headers
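//!
//! A minimal usage sketch (illustrative only and marked `ignore`; how this
//! crate re-exports `StreamManager` is assumed here, not verified):
//!
//! ```ignore
//! use std::sync::Arc;
//! use turul_mcp_session_storage::InMemorySessionStorage;
//!
//! async fn demo() -> Result<(), StreamError> {
//!     let storage = Arc::new(InMemorySessionStorage::new());
//!     let manager = StreamManager::new(storage);
//!
//!     // Deliver an event to one SSE connection of a session; the event is also
//!     // persisted so it can be replayed via Last-Event-ID on reconnect.
//!     let _event_id = manager
//!         .broadcast_to_session("some-session-id", "progress".to_string(), serde_json::json!({"percent": 50}))
//!         .await?;
//!     Ok(())
//! }
//! ```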

use std::sync::Arc;
use std::collections::HashMap;
use std::pin::Pin;
use futures::{Stream, StreamExt};
use hyper::{Response, StatusCode};
use http_body_util::{StreamBody, BodyExt};
use bytes::Bytes;
use hyper::header::{CONTENT_TYPE, CACHE_CONTROL, ACCESS_CONTROL_ALLOW_ORIGIN};
use serde_json::Value;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, info, error, warn};

use turul_mcp_session_storage::SseEvent;

/// Connection ID for tracking individual SSE streams
pub type ConnectionId = String;

/// Enhanced stream manager with resumability support (MCP spec compliant)
pub struct StreamManager {
    /// Session storage backend for persistence
    storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
    /// Per-session connections for real-time events (MCP compliant - no broadcasting)
    connections: Arc<RwLock<HashMap<String, HashMap<ConnectionId, mpsc::Sender<SseEvent>>>>>,
    /// Configuration
    config: StreamConfig,
    /// Unique instance ID for debugging
    instance_id: String,
}

/// Configuration for stream management
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Channel buffer size for per-connection event channels
    pub channel_buffer_size: usize,
    /// Maximum events to replay on reconnection
    pub max_replay_events: usize,
    /// Keep-alive interval in seconds
    pub keepalive_interval_seconds: u64,
    /// CORS configuration
    pub cors_origin: String,
}

impl Default for StreamConfig {
    fn default() -> Self {
        Self {
            channel_buffer_size: 1000,
            max_replay_events: 100,
            keepalive_interval_seconds: 30,
            cors_origin: "*".to_string(),
        }
    }
}

/// SSE stream wrapper that formats events properly (MCP compliant - one connection per stream)
pub struct SseStream {
    /// Underlying event stream
    stream: Option<Pin<Box<dyn Stream<Item = SseEvent> + Send>>>,
    /// Session metadata
    session_id: String,
    /// Connection identifier (for MCP spec compliance)
    connection_id: ConnectionId,
}

impl SseStream {
    /// Get the session ID this stream belongs to
    pub fn session_id(&self) -> &str {
        &self.session_id
    }

    /// Get the connection ID for this stream
    pub fn connection_id(&self) -> &str {
        &self.connection_id
    }

    /// Get stream identifier for logging (session + connection)
    pub fn stream_identifier(&self) -> String {
        format!("{}:{}", self.session_id, self.connection_id)
    }
}

impl Drop for SseStream {
    fn drop(&mut self) {
        debug!("🔥 DROP: SseStream - session={}, connection={}",
               self.session_id, self.connection_id);
        if self.stream.is_some() {
            debug!("🔥 Stream still present during drop - this indicates early cleanup");
        } else {
            debug!("🔥 Stream was properly extracted before drop");
        }
    }
}

/// Error type for stream management
#[derive(Debug, thiserror::Error)]
pub enum StreamError {
    #[error("Session not found: {0}")]
    SessionNotFound(String),
    #[error("Stream not found: session={0}, stream={1}")]
    StreamNotFound(String, String),
    #[error("Storage error: {0}")]
    StorageError(String),
    #[error("Connection error: {0}")]
    ConnectionError(String),
}

impl StreamManager {
    /// Create new stream manager with session storage backend
    pub fn new(storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>) -> Self {
        Self::with_config(storage, StreamConfig::default())
    }

    /// Create stream manager with custom configuration
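    ///
    /// A hedged sketch of tuning the configuration (the field values are
    /// arbitrary examples, and `storage` stands for any
    /// `Arc<BoxedSessionStorage>` backend):
    ///
    /// ```ignore
    /// let config = StreamConfig {
    ///     channel_buffer_size: 2048,
    ///     max_replay_events: 500,
    ///     keepalive_interval_seconds: 15,
    ///     cors_origin: "https://example.com".to_string(),
    /// };
    /// let manager = StreamManager::with_config(storage, config);
    /// ```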
    pub fn with_config(storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>, config: StreamConfig) -> Self {
        use uuid::Uuid;
        let instance_id = Uuid::now_v7().to_string();
        debug!("🔧 Creating StreamManager instance: {}", instance_id);
        Self {
            storage,
            connections: Arc::new(RwLock::new(HashMap::new())),
            config,
            instance_id,
        }
    }

    /// Handle SSE connection request with proper resumability
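    ///
    /// A hedged sketch of calling this from a hyper handler; `manager`, `req`,
    /// and `session_id` are assumed to exist in the surrounding handler code:
    ///
    /// ```ignore
    /// // Resume from the client-supplied Last-Event-ID header, if present.
    /// let last_event_id = req.headers()
    ///     .get("last-event-id")
    ///     .and_then(|v| v.to_str().ok())
    ///     .and_then(|s| s.parse::<u64>().ok());
    ///
    /// let response = manager
    ///     .handle_sse_connection(session_id, uuid::Uuid::now_v7().to_string(), last_event_id)
    ///     .await?;
    /// ```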
    pub async fn handle_sse_connection(
        &self,
        session_id: String,
        connection_id: ConnectionId,
        last_event_id: Option<u64>,
    ) -> Result<Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>, StreamError> {
        // Verify session exists
        if self.storage.get_session(&session_id).await
            .map_err(|e| StreamError::StorageError(e.to_string()))?
            .is_none()
        {
            return Err(StreamError::SessionNotFound(session_id));
        }

        // Create the SSE stream (one per connection, MCP compliant)
        let sse_stream = self.create_sse_stream(session_id.clone(), connection_id.clone(), last_event_id).await?;

        // Convert to HTTP response
        let response = self.stream_to_response(sse_stream).await;

        info!("Created SSE connection: session={}, connection={}, last_event_id={:?}",
              session_id, connection_id, last_event_id);

        Ok(response)
    }

    /// Create SSE stream with resumability support (MCP compliant - no broadcast)
    async fn create_sse_stream(
        &self,
        session_id: String,
        connection_id: ConnectionId,
        last_event_id: Option<u64>,
    ) -> Result<SseStream, StreamError> {
        // Create mpsc channel for this specific connection (MCP compliant)
        let (sender, mut receiver) = mpsc::channel(self.config.channel_buffer_size);

        // Register this connection with the session
        self.register_connection(&session_id, connection_id.clone(), sender).await;

        // Create the combined stream
        let storage = self.storage.clone();
        let session_id_clone = session_id.clone();
        let connection_id_clone = connection_id.clone();
        let config = self.config.clone();

        let combined_stream = async_stream::stream! {
            // 1. First, yield any historical events (resumability)
            if let Some(after_event_id) = last_event_id {
                debug!("Replaying events after ID {} for session={}, connection={}",
                       after_event_id, session_id_clone, connection_id_clone);

                match storage.get_events_after(&session_id_clone, after_event_id).await {
                    Ok(events) => {
                        for event in events.into_iter().take(config.max_replay_events) {
                            yield event;
                        }
                    },
                    Err(e) => {
                        error!("Failed to get historical events: {}", e);
                        // Continue with real-time events even if historical replay fails
                    }
                }
            }

            // 2. Then, stream real-time events from dedicated channel
            let mut keepalive_interval = tokio::time::interval(
                tokio::time::Duration::from_secs(config.keepalive_interval_seconds)
            );

            loop {
                tokio::select! {
                    // Real-time events from this connection's channel
                    event = receiver.recv() => {
                        match event {
                            Some(event) => {
                                debug!("📨 Received event for connection {}: {}", connection_id_clone, event.event_type);
                                yield event;
                            },
                            None => {
                                debug!("Connection channel closed for session={}, connection={}", session_id_clone, connection_id_clone);
                                break;
                            }
                        }
                    },

                    // Keep-alive pings
                    _ = keepalive_interval.tick() => {
                        let keepalive_event = SseEvent {
                            id: 0, // Keep-alive events don't need persistent IDs
                            timestamp: chrono::Utc::now().timestamp_millis() as u64,
                            event_type: "ping".to_string(),
                            data: serde_json::json!({"type": "keepalive"}),
                            retry: None,
                        };
                        yield keepalive_event;
                    }
                }
            }

            // Clean up connection when stream ends
            debug!("🧹 Cleaning up connection: session={}, connection={}", session_id_clone, connection_id_clone);
        };

        Ok(SseStream {
            stream: Some(Box::pin(combined_stream)),
            session_id,
            connection_id,
        })
    }

    /// Register a new connection for a session (MCP compliant)
    async fn register_connection(
        &self,
        session_id: &str,
        connection_id: ConnectionId,
        sender: mpsc::Sender<SseEvent>
    ) {
        let mut connections = self.connections.write().await;

        debug!("[{}] 🔍 BEFORE registration: HashMap has {} sessions", self.instance_id, connections.len());
        for (sid, conns) in connections.iter() {
            debug!("[{}] 🔍 Existing session before: {} with {} connections", self.instance_id, sid, conns.len());
        }

        // Get or create session entry
        let session_connections = connections.entry(session_id.to_string())
            .or_insert_with(HashMap::new);

        // Add this connection
        session_connections.insert(connection_id.clone(), sender);

        debug!("[{}] 🔗 Registered connection: session={}, connection={}, total_connections={}",
               self.instance_id, session_id, connection_id, session_connections.len());

        debug!("[{}] 🔍 AFTER registration: HashMap has {} sessions", self.instance_id, connections.len());
        for (sid, conns) in connections.iter() {
            debug!("[{}] 🔍 Session after: {} with {} connections", self.instance_id, sid, conns.len());
        }
    }

    /// Remove a connection when it's closed
    pub async fn unregister_connection(&self, session_id: &str, connection_id: &ConnectionId) {
        debug!("🔴 UNREGISTER called for session={}, connection={}", session_id, connection_id);
        let mut connections = self.connections.write().await;

        debug!("🔍 BEFORE unregister: HashMap has {} sessions", connections.len());

        if let Some(session_connections) = connections.get_mut(session_id) {
            if session_connections.remove(connection_id).is_some() {
                debug!("🔌 Unregistered connection: session={}, connection={}", session_id, connection_id);

                // Clean up empty sessions
                if session_connections.is_empty() {
                    connections.remove(session_id);
                    debug!("🧹 Removed empty session: {}", session_id);
                }
            }
        }

        debug!("🔍 AFTER unregister: HashMap has {} sessions", connections.len());
    }

    /// Convert SSE stream to HTTP response with proper headers
    async fn stream_to_response(&self, mut sse_stream: SseStream) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>> {
        // Extract session info before moving the stream
        let session_id = sse_stream.session_id().to_string();
        let stream_identifier = sse_stream.stream_identifier();

        // Log stream creation with session identifier
        info!("Converting SSE stream to HTTP response: {}", stream_identifier);
        debug!("Stream details: session_id={}", session_id);

        // Transform events to SSE format and create proper HTTP frames
        // Extract stream from Option wrapper
        let stream = sse_stream.stream.take().expect("Stream should be present in SseStream");

        let formatted_stream = stream.map(|event| {
            let sse_formatted = event.format();
            debug!("📡 Streaming SSE event: id={}, event_type={}", event.id, event.event_type);
            Ok(hyper::body::Frame::data(Bytes::from(sse_formatted)))
        });

        // Create streaming body from the actual event stream and box it
        let body = StreamBody::new(formatted_stream).boxed_unsync();

        // Build response with proper SSE headers for streaming
        Response::builder()
            .status(StatusCode::OK)
            .header(CONTENT_TYPE, "text/event-stream")
            .header(CACHE_CONTROL, "no-cache")
            .header(ACCESS_CONTROL_ALLOW_ORIGIN, &self.config.cors_origin)
            .header("Connection", "keep-alive")
            .body(body)
            .unwrap()
    }

    /// Send event to specific session (MCP compliant - ONE connection only)
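    ///
    /// A hedged sketch (the event type and payload are illustrative; `manager`
    /// and `session_id` come from the surrounding code):
    ///
    /// ```ignore
    /// // Persist the event and, if the session has an active SSE connection,
    /// // deliver it to exactly one connection.
    /// let event_id = manager
    ///     .broadcast_to_session(&session_id, "notifications/message".to_string(),
    ///         serde_json::json!({"level": "info", "data": "tool finished"}))
    ///     .await?;
    /// ```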
    pub async fn broadcast_to_session(
        &self,
        session_id: &str,
        event_type: String,
        data: Value,
    ) -> Result<u64, StreamError> {
        // Create the event
        let event = SseEvent::new(event_type, data);

        // Store event for resumability
        let stored_event = self.storage.store_event(session_id, event).await
            .map_err(|e| StreamError::StorageError(e.to_string()))?;

        // DEBUG: Check connection state more thoroughly
        let connections = self.connections.read().await;
        debug!("[{}] 🔍 Checking connections for session {}: connections hashmap has {} sessions",
               self.instance_id, session_id, connections.len());

        if let Some(session_connections) = connections.get(session_id) {
            debug!("🔍 Session {} found with {} connections", session_id, session_connections.len());

            if !session_connections.is_empty() {
                // Pick one available connection (HashMap iteration order is arbitrary; MCP requires targeting a single stream)
                let (selected_connection_id, selected_sender) = session_connections.iter().next().unwrap();

                // Check if sender is closed
                if selected_sender.is_closed() {
                    warn!("🔌 Sender is closed for connection: session={}, connection={}",
                          session_id, selected_connection_id);
                    debug!("📭 Connection sender was closed, event stored for reconnection");
                } else {
                    debug!("✅ Sender is open, attempting to send to connection: session={}, connection={}",
                           session_id, selected_connection_id);

                    match selected_sender.try_send(stored_event.clone()) {
                        Ok(()) => {
                            info!("✅ Sent notification to ONE connection: session={}, connection={}, event_id={}, method={}",
                                  session_id, selected_connection_id, stored_event.id, stored_event.event_type);
                        },
                        Err(mpsc::error::TrySendError::Full(_)) => {
                            warn!("⚠️ Connection buffer full: session={}, connection={}", session_id, selected_connection_id);
                            // Event is still stored for reconnection
                        },
                        Err(mpsc::error::TrySendError::Closed(_)) => {
                            warn!("🔌 Connection closed during send: session={}, connection={}", session_id, selected_connection_id);
                            // Event is still stored for reconnection
                        }
                    }
                }
            } else {
                debug!("📭 No active connections for session: {} (event stored for reconnection)", session_id);
            }
        } else {
            debug!("📭 No connections registered for session: {} (event stored for reconnection)", session_id);

            // DEBUG: List all sessions in connections
            for (sid, conns) in connections.iter() {
                debug!("🔍 Available session: {} with {} connections", sid, conns.len());
            }
        }

        Ok(stored_event.id)
    }

    /// Broadcast to all sessions (for server-wide notifications)
    pub async fn broadcast_to_all_sessions(
        &self,
        event_type: String,
        data: Value,
    ) -> Result<Vec<String>, StreamError> {
        // Get all session IDs
        let session_ids = self.storage.list_sessions().await
            .map_err(|e| StreamError::StorageError(e.to_string()))?;

        let mut failed_sessions = Vec::new();

        for session_id in session_ids {
            if let Err(e) = self.broadcast_to_session(&session_id, event_type.clone(), data.clone()).await {
                error!("Failed to broadcast to session {}: {}", session_id, e);
                failed_sessions.push(session_id);
            }
        }

        Ok(failed_sessions)
    }

    /// Clean up closed connections
    pub async fn cleanup_connections(&self) -> usize {
        debug!("🧹 CLEANUP_CONNECTIONS called");
        let mut connections = self.connections.write().await;
        let mut total_cleaned = 0;

        debug!("🔍 BEFORE cleanup: HashMap has {} sessions", connections.len());

        // Clean up closed connections
        connections.retain(|session_id, session_connections| {
            let initial_count = session_connections.len();

            // Remove closed connections
            session_connections.retain(|connection_id, sender| {
                if sender.is_closed() {
                    debug!("🧹 Cleaned up closed connection: session={}, connection={}", session_id, connection_id);
                    false
                } else {
                    true
                }
            });

            let cleaned_count = initial_count - session_connections.len();
            total_cleaned += cleaned_count;

            // Keep session if it has active connections
            !session_connections.is_empty()
        });

        if total_cleaned > 0 {
            info!("Cleaned up {} inactive connections", total_cleaned);
        }

        total_cleaned
    }

    /// Create SSE stream for POST requests (MCP Streamable HTTP)
    pub async fn create_post_sse_stream(
        &self,
        session_id: String,
        response: turul_mcp_json_rpc_server::JsonRpcResponse,
    ) -> Result<hyper::Response<http_body_util::Full<bytes::Bytes>>, StreamError> {
        // Verify session exists
        if self.storage.get_session(&session_id).await
            .map_err(|e| StreamError::StorageError(e.to_string()))?
            .is_none()
        {
            return Err(StreamError::SessionNotFound(session_id));
        }

        info!("Creating POST SSE stream for session: {}", session_id);

        // Create the SSE response body
        let response_json = serde_json::to_string(&response)
            .map_err(|e| StreamError::StorageError(format!("Failed to serialize response: {}", e)))?;

        // For MCP Streamable HTTP, we create an SSE stream that:
        // 1. Sends any pending notifications for this session
        // 2. Sends the JSON-RPC response as an SSE event
        // 3. Closes the stream
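        //
        // Illustrative shape of the resulting body (a sketch only; the actual
        // JSON payloads depend on the stored events and the serialized
        // JSON-RPC response):
        //
        //   event: data
        //   data: {"jsonrpc":"2.0","method":"notifications/progress","params":{...}}
        //
        //   event: data
        //   data: {"jsonrpc":"2.0","id":1,"result":{...}}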

        let mut sse_content = String::new();

        // 1. Include recent notifications that may have been generated during tool execution
        // Note: Since tool notifications are processed asynchronously, we need to wait a moment
        // and then check for recent events to include in the POST SSE response
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        if let Ok(events) = self.storage.get_recent_events(&session_id, 10).await {
            for event in events {
                // Convert stored SSE event to notification JSON-RPC format
                if event.event_type != "ping" { // Skip keepalive events
                    let notification_sse = format!(
                        "event: data\ndata: {}\n\n",
                        event.data
                    );
                    sse_content.push_str(&notification_sse);
                    debug!("📤 Including notification in POST SSE stream: event_type={}", event.event_type);
                }
            }
        }

        // 2. Add the JSON-RPC tool response
        let response_sse = format!(
            "event: data\ndata: {}\n\n",
            response_json
        );
        sse_content.push_str(&response_sse);

        debug!("📡 POST SSE response created: session={}, content_length={}", session_id, sse_content.len());

        // Build response with proper SSE headers including MCP session ID
        Ok(hyper::Response::builder()
            .status(hyper::StatusCode::OK)
            .header(hyper::header::CONTENT_TYPE, "text/event-stream")
            .header(hyper::header::CACHE_CONTROL, "no-cache")
            .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, &self.config.cors_origin)
            .header("Connection", "keep-alive")
            .header("Mcp-Session-Id", &session_id)
            .body(http_body_util::Full::new(bytes::Bytes::from(sse_content)))
            .unwrap())
    }

    /// Get statistics about active streams
    pub async fn get_stats(&self) -> StreamStats {
        let connections = self.connections.read().await;
        let session_count = self.storage.session_count().await.unwrap_or(0);
        let event_count = self.storage.event_count().await.unwrap_or(0);

        // Count total active connections
        let total_connections: usize = connections.values()
            .map(|session_connections| session_connections.len())
            .sum();

        StreamStats {
            active_broadcasters: total_connections, // Now tracks active connections
            total_sessions: session_count,
            total_events: event_count,
            channel_buffer_size: self.config.channel_buffer_size,
        }
    }
}

impl Drop for StreamManager {
    fn drop(&mut self) {
        debug!("🔥 DROP: StreamManager instance {} - this may cause connection loss!",
               self.instance_id);
        debug!("🔥 If this appears during request processing, it indicates an architecture problem");
    }
}

/// Stream manager statistics
#[derive(Debug, Clone)]
pub struct StreamStats {
    pub active_broadcasters: usize,
    pub total_sessions: usize,
    pub total_events: usize,
    pub channel_buffer_size: usize,
}

// Helper to create async stream
#[cfg(not(test))]
use async_stream;

#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_session_storage::InMemorySessionStorage;
    use turul_mcp_protocol::ServerCapabilities;

    #[tokio::test]
    async fn test_stream_manager_creation() {
        let storage = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(storage);

        let stats = manager.get_stats().await;
        assert_eq!(stats.active_broadcasters, 0);
        assert_eq!(stats.total_sessions, 0);
    }

    #[tokio::test]
    async fn test_broadcast_to_session() {
        let storage = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(storage.clone());

        // Create a session
        let session = storage.create_session(ServerCapabilities::default()).await.unwrap();
        let session_id = session.session_id.clone();

        // Broadcast an event
        let event_id = manager.broadcast_to_session(
            &session_id,
            "test".to_string(),
            serde_json::json!({"message": "test"})
        ).await.unwrap();

        assert!(event_id > 0);

        // Verify event was stored
        let events = storage.get_events_after(&session_id, 0).await.unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id, event_id);
    }
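
    /// Sketch of a connection-lifecycle test (a hedged addition; it uses only
    /// APIs defined in this module, and the private `register_connection` is
    /// reachable from this child test module).
    #[tokio::test]
    async fn test_connection_registration_and_cleanup() {
        let storage = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(storage.clone());

        // Create a session and register a single connection for it
        let session = storage.create_session(ServerCapabilities::default()).await.unwrap();
        let session_id = session.session_id.clone();

        let (sender, receiver) = mpsc::channel(8);
        manager.register_connection(&session_id, "conn-1".to_string(), sender).await;
        assert_eq!(manager.get_stats().await.active_broadcasters, 1);

        // Dropping the receiver closes the channel; cleanup should remove the connection
        drop(receiver);
        let cleaned = manager.cleanup_connections().await;
        assert_eq!(cleaned, 1);
        assert_eq!(manager.get_stats().await.active_broadcasters, 0);
    }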
}