turul_http_mcp_server/
stream_manager.rs

1//! Enhanced Stream Manager with MCP 2025-06-18 Resumability
2//!
3//! This module provides proper SSE stream management with:
4//! - Event IDs for resumability
5//! - Last-Event-ID header support
6//! - Per-session event targeting (not broadcast to all)
7//! - Event persistence and replay
8//! - Proper HTTP status codes and headers
9
10use std::sync::Arc;
11use std::collections::HashMap;
12use std::pin::Pin;
13use futures::{Stream, StreamExt};
14use hyper::{Response, StatusCode};
15use http_body_util::{StreamBody, BodyExt};
16use bytes::Bytes;
17use hyper::header::{CONTENT_TYPE, CACHE_CONTROL, ACCESS_CONTROL_ALLOW_ORIGIN};
18use serde_json::Value;
19use tokio::sync::{mpsc, RwLock};
20use tracing::{debug, info, error, warn};
21
22use turul_mcp_session_storage::SseEvent;
23
/// Identifier of one SSE connection inside a session.
///
/// Kept as a plain owned `String` so it can be used directly as a `HashMap` key
/// and in log output.
pub type ConnectionId = String;
26
27/// Enhanced stream manager with resumability support (MCP spec compliant)
28pub struct StreamManager {
29    /// Session storage backend for persistence
30    storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
31    /// Per-session connections for real-time events (MCP compliant - no broadcasting)
32    connections: Arc<RwLock<HashMap<String, HashMap<ConnectionId, mpsc::Sender<SseEvent>>>>>,
33    /// Configuration
34    config: StreamConfig,
35    /// Unique instance ID for debugging
36    instance_id: String,
37}
38
/// Tunable settings for [`StreamManager`].
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Capacity of the `mpsc` channel created for each live SSE connection.
    pub channel_buffer_size: usize,
    /// Upper bound on stored events replayed when a client resumes with `Last-Event-ID`.
    pub max_replay_events: usize,
    /// Seconds between keep-alive `ping` events on GET streams.
    pub keepalive_interval_seconds: u64,
    /// Value sent in the `Access-Control-Allow-Origin` response header.
    pub cors_origin: String,
}

impl Default for StreamConfig {
    /// 1000-slot channels, at most 100 replayed events, 30-second keep-alives and a
    /// wildcard (`*`) CORS origin.
    fn default() -> Self {
        let cors_origin = String::from("*");
        StreamConfig {
            cors_origin,
            keepalive_interval_seconds: 30,
            max_replay_events: 100,
            channel_buffer_size: 1000,
        }
    }
}
62
63/// SSE stream wrapper that formats events properly (MCP compliant - one connection per stream)
64pub struct SseStream {
65    /// Underlying event stream
66    stream: Option<Pin<Box<dyn Stream<Item = SseEvent> + Send>>>,
67    /// Session metadata
68    session_id: String,
69    /// Connection identifier (for MCP spec compliance)
70    connection_id: ConnectionId,
71}
72
73impl SseStream {
74    /// Get the session ID this stream belongs to
75    pub fn session_id(&self) -> &str {
76        &self.session_id
77    }
78
79    /// Get the connection ID for this stream
80    pub fn connection_id(&self) -> &str {
81        &self.connection_id
82    }
83
84    /// Get stream identifier for logging (session + connection)
85    pub fn stream_identifier(&self) -> String {
86        format!("{}:{}", self.session_id, self.connection_id)
87    }
88}
89
90impl Drop for SseStream {
91    fn drop(&mut self) {
92        debug!("๐Ÿ”ฅ DROP: SseStream - session={}, connection={}",
93               self.session_id, self.connection_id);
94        if self.stream.is_some() {
95            debug!("๐Ÿ”ฅ Stream still present during drop - this indicates early cleanup");
96        } else {
97            debug!("๐Ÿ”ฅ Stream was properly extracted before drop");
98        }
99    }
100}
101
102/// Error type for stream management
103#[derive(Debug, thiserror::Error)]
104pub enum StreamError {
105    #[error("Session not found: {0}")]
106    SessionNotFound(String),
107    #[error("Stream not found: session={0}, stream={1}")]
108    StreamNotFound(String, String),
109    #[error("Storage error: {0}")]
110    StorageError(String),
111    #[error("Connection error: {0}")]
112    ConnectionError(String),
113    #[error("No connections available for session: {0}")]
114    NoConnections(String),
115}
116
117impl StreamManager {
118    /// Create new stream manager with session storage backend
119    pub fn new(storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>) -> Self {
120        Self::with_config(storage, StreamConfig::default())
121    }
122
123    /// Create stream manager with custom configuration
124    pub fn with_config(storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>, config: StreamConfig) -> Self {
125        use uuid::Uuid;
126        let instance_id = Uuid::now_v7().to_string();
127        debug!("๐Ÿ”ง Creating StreamManager instance: {}", instance_id);
128        Self {
129            storage,
130            connections: Arc::new(RwLock::new(HashMap::new())),
131            config,
132            instance_id,
133        }
134    }
135
136    /// Handle SSE connection request with proper resumability
137    pub async fn handle_sse_connection(
138        &self,
139        session_id: String,
140        connection_id: ConnectionId,
141        last_event_id: Option<u64>,
142    ) -> Result<Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>, StreamError> {
143        // Verify session exists
144        if self.storage.get_session(&session_id).await
145            .map_err(|e| StreamError::StorageError(e.to_string()))?
146            .is_none()
147        {
148            return Err(StreamError::SessionNotFound(session_id));
149        }
150
151        // Create the SSE stream (one per connection, MCP compliant)
152        let sse_stream = self.create_sse_stream(session_id.clone(), connection_id.clone(), last_event_id).await?;
153
154        // Convert to HTTP response
155        let response = self.stream_to_response(sse_stream).await;
156
157        info!("Created SSE connection: session={}, connection={}, last_event_id={:?}",
158              session_id, connection_id, last_event_id);
159
160        Ok(response)
161    }
162
163    /// Create SSE stream with resumability support (MCP compliant - no broadcast)
164    async fn create_sse_stream(
165        &self,
166        session_id: String,
167        connection_id: ConnectionId,
168        last_event_id: Option<u64>,
169    ) -> Result<SseStream, StreamError> {
170        // Create mpsc channel for this specific connection (MCP compliant)
171        let (sender, mut receiver) = mpsc::channel(self.config.channel_buffer_size);
172
173        // Register this connection with the session
174        self.register_connection(&session_id, connection_id.clone(), sender).await;
175
176        // Create the combined stream
177        let storage = self.storage.clone();
178        let session_id_clone = session_id.clone();
179        let connection_id_clone = connection_id.clone();
180        let config = self.config.clone();
181
182        let combined_stream = async_stream::stream! {
183            // 1. First, yield any historical events (resumability)
184            if let Some(after_event_id) = last_event_id {
185                debug!("Replaying events after ID {} for session={}, connection={}",
186                       after_event_id, session_id_clone, connection_id_clone);
187
188                match storage.get_events_after(&session_id_clone, after_event_id).await {
189                    Ok(events) => {
190                        for event in events.into_iter().take(config.max_replay_events) {
191                            yield event;
192                        }
193                    },
194                    Err(e) => {
195                        error!("Failed to get historical events: {}", e);
196                        // Continue with real-time events even if historical replay fails
197                    }
198                }
199            }
200
201            // 2. Then, stream real-time events from dedicated channel
202            let mut keepalive_interval = tokio::time::interval(
203                tokio::time::Duration::from_secs(config.keepalive_interval_seconds)
204            );
205
206            loop {
207                tokio::select! {
208                    // Real-time events from this connection's channel
209                    event = receiver.recv() => {
210                        match event {
211                            Some(event) => {
212                                debug!("๐Ÿ“จ Received event for connection {}: {}", connection_id_clone, event.event_type);
213                                yield event;
214                            },
215                            None => {
216                                debug!("Connection channel closed for session={}, connection={}", session_id_clone, connection_id_clone);
217                                break;
218                            }
219                        }
220                    },
221
222                    // Keep-alive pings
223                    _ = keepalive_interval.tick() => {
224                        let keepalive_event = SseEvent {
225                            id: 0, // Keep-alive events don't need persistent IDs
226                            timestamp: chrono::Utc::now().timestamp_millis() as u64,
227                            event_type: "ping".to_string(),
228                            data: serde_json::json!({"type": "keepalive"}),
229                            retry: None,
230                        };
231                        yield keepalive_event;
232                    }
233                }
234            }
235
236            // Clean up connection when stream ends
237            debug!("๐Ÿงน Cleaning up connection: session={}, connection={}", session_id_clone, connection_id_clone);
238        };
239
240        Ok(SseStream {
241            stream: Some(Box::pin(combined_stream)),
242            session_id,
243            connection_id,
244        })
245    }
246
247    /// Register a new connection for a session (MCP compliant)
248    async fn register_connection(
249        &self,
250        session_id: &str,
251        connection_id: ConnectionId,
252        sender: mpsc::Sender<SseEvent>
253    ) {
254        let mut connections = self.connections.write().await;
255
256        debug!("[{}] ๐Ÿ” BEFORE registration: HashMap has {} sessions", self.instance_id, connections.len());
257        for (sid, conns) in connections.iter() {
258            debug!("[{}] ๐Ÿ” Existing session before: {} with {} connections", self.instance_id, sid, conns.len());
259        }
260
261        // Get or create session entry
262        let session_connections = connections.entry(session_id.to_string())
263            .or_insert_with(HashMap::new);
264
265        // Add this connection
266        session_connections.insert(connection_id.clone(), sender);
267
268        debug!("[{}] ๐Ÿ”— Registered connection: session={}, connection={}, total_connections={}",
269               self.instance_id, session_id, connection_id, session_connections.len());
270
271        debug!("[{}] ๐Ÿ” AFTER registration: HashMap has {} sessions", self.instance_id, connections.len());
272        for (sid, conns) in connections.iter() {
273            debug!("[{}] ๐Ÿ” Session after: {} with {} connections", self.instance_id, sid, conns.len());
274        }
275    }
276
277    /// Remove a connection when it's closed
278    pub async fn unregister_connection(&self, session_id: &str, connection_id: &ConnectionId) {
279        debug!("๐Ÿ”ด UNREGISTER called for session={}, connection={}", session_id, connection_id);
280        let mut connections = self.connections.write().await;
281
282        debug!("๐Ÿ” BEFORE unregister: HashMap has {} sessions", connections.len());
283
284        if let Some(session_connections) = connections.get_mut(session_id) {
285            if session_connections.remove(connection_id).is_some() {
286                debug!("๐Ÿ”Œ Unregistered connection: session={}, connection={}", session_id, connection_id);
287
288                // Clean up empty sessions
289                if session_connections.is_empty() {
290                    connections.remove(session_id);
291                    debug!("๐Ÿงน Removed empty session: {}", session_id);
292                }
293            }
294        }
295
296        debug!("๐Ÿ” AFTER unregister: HashMap has {} sessions", connections.len());
297    }
298
299    /// Convert SSE stream to HTTP response with proper headers
300    async fn stream_to_response(&self, mut sse_stream: SseStream) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>> {
301        // Extract session info before moving the stream
302        let session_id = sse_stream.session_id().to_string();
303        let stream_identifier = sse_stream.stream_identifier();
304
305        // Log stream creation with session identifier
306        info!("Converting SSE stream to HTTP response: {}", stream_identifier);
307        debug!("Stream details: session_id={}", session_id);
308
309        // Transform events to SSE format and create proper HTTP frames
310        // Extract stream from Option wrapper
311        let stream = sse_stream.stream.take().expect("Stream should be present in SseStream");
312
313        let formatted_stream = stream.map(|event| {
314            let sse_formatted = event.format();
315            debug!("๐Ÿ“ก Streaming SSE event: id={}, event_type={}", event.id, event.event_type);
316            Ok(hyper::body::Frame::data(Bytes::from(sse_formatted)))
317        });
318
319        // Create streaming body from the actual event stream and box it
320        let body = StreamBody::new(formatted_stream).boxed_unsync();
321
322        // Build response with proper SSE headers for streaming
323        Response::builder()
324            .status(StatusCode::OK)
325            .header(CONTENT_TYPE, "text/event-stream")
326            .header(CACHE_CONTROL, "no-cache")
327            .header(ACCESS_CONTROL_ALLOW_ORIGIN, &self.config.cors_origin)
328            .header("Connection", "keep-alive")
329            .body(body)
330            .unwrap()
331    }
332
333    /// Check if a session has any active SSE connections
334    pub async fn has_connections(&self, session_id: &str) -> bool {
335        let connections = self.connections.read().await;
336        connections.get(session_id)
337            .map(|session_connections| !session_connections.is_empty())
338            .unwrap_or(false)
339    }
340
341    /// Send event to specific session (MCP compliant - ONE connection only)
342    pub async fn broadcast_to_session(
343        &self,
344        session_id: &str,
345        event_type: String,
346        data: Value,
347    ) -> Result<u64, StreamError> {
348        self.broadcast_to_session_with_options(session_id, event_type, data, true).await
349    }
350
351    /// Send event to specific session with option to suppress when no connections exist
352    pub async fn broadcast_to_session_with_options(
353        &self,
354        session_id: &str,
355        event_type: String,
356        data: Value,
357        store_when_no_connections: bool,
358    ) -> Result<u64, StreamError> {
359        // Check if we should suppress notifications when no connections exist
360        if !store_when_no_connections && !self.has_connections(session_id).await {
361            debug!("๐Ÿšซ Suppressing notification for session {} (no connections, store_when_no_connections=false)", session_id);
362            return Err(StreamError::NoConnections(session_id.to_string()));
363        }
364
365        // Create the event
366        let event = SseEvent::new(event_type.clone(), data);
367
368        // Store event for resumability (always store for compliant clients)
369        let stored_event = self.storage.store_event(session_id, event).await
370            .map_err(|e| StreamError::StorageError(e.to_string()))?;
371
372        // DEBUG: Check connection state more thoroughly
373        let connections = self.connections.read().await;
374        debug!("[{}] ๐Ÿ” Checking connections for session {}: connections hashmap has {} sessions",
375               self.instance_id, session_id, connections.len());
376
377        if let Some(session_connections) = connections.get(session_id) {
378            debug!("๐Ÿ” Session {} found with {} connections", session_id, session_connections.len());
379
380            if !session_connections.is_empty() {
381                // Pick the FIRST available connection (MCP compliant)
382                let (selected_connection_id, selected_sender) = session_connections.iter().next().unwrap();
383
384                // Check if sender is closed
385                if selected_sender.is_closed() {
386                    warn!("๐Ÿ”Œ Sender is closed for connection: session={}, connection={}",
387                          session_id, selected_connection_id);
388                    debug!("๐Ÿ“ญ Connection sender was closed, event stored for reconnection");
389                } else {
390                    debug!("โœ… Sender is open, attempting to send to connection: session={}, connection={}",
391                           session_id, selected_connection_id);
392
393                    match selected_sender.try_send(stored_event.clone()) {
394                        Ok(()) => {
395                            info!("โœ… Sent notification to ONE connection: session={}, connection={}, event_id={}, method={}",
396                                  session_id, selected_connection_id, stored_event.id, stored_event.event_type);
397                        },
398                        Err(mpsc::error::TrySendError::Full(_)) => {
399                            warn!("โš ๏ธ Connection buffer full: session={}, connection={}", session_id, selected_connection_id);
400                            // Event is still stored for reconnection
401                        },
402                        Err(mpsc::error::TrySendError::Closed(_)) => {
403                            warn!("๐Ÿ”Œ Connection closed during send: session={}, connection={}", session_id, selected_connection_id);
404                            // Event is still stored for reconnection
405                        }
406                    }
407                }
408            } else {
409                debug!("๐Ÿ“ญ No active connections for session: {} (event stored for reconnection)", session_id);
410            }
411        } else {
412            debug!("๐Ÿ“ญ No connections registered for session: {} (event stored for reconnection)", session_id);
413
414            // DEBUG: List all sessions in connections
415            for (sid, conns) in connections.iter() {
416                debug!("๐Ÿ” Available session: {} with {} connections", sid, conns.len());
417            }
418        }
419
420        Ok(stored_event.id)
421    }
422
423    /// Broadcast to all sessions (for server-wide notifications)
424    pub async fn broadcast_to_all_sessions(
425        &self,
426        event_type: String,
427        data: Value,
428    ) -> Result<Vec<String>, StreamError> {
429        // Get all session IDs
430        let session_ids = self.storage.list_sessions().await
431            .map_err(|e| StreamError::StorageError(e.to_string()))?;
432
433        let mut failed_sessions = Vec::new();
434
435        for session_id in session_ids {
436            if let Err(e) = self.broadcast_to_session(&session_id, event_type.clone(), data.clone()).await {
437                error!("Failed to broadcast to session {}: {}", session_id, e);
438                failed_sessions.push(session_id);
439            }
440        }
441
442        Ok(failed_sessions)
443    }
444
445    /// Clean up closed connections
446    pub async fn cleanup_connections(&self) -> usize {
447        debug!("๐Ÿงน CLEANUP_CONNECTIONS called");
448        let mut connections = self.connections.write().await;
449        let mut total_cleaned = 0;
450
451        debug!("๐Ÿ” BEFORE cleanup: HashMap has {} sessions", connections.len());
452
453        // Clean up closed connections
454        connections.retain(|session_id, session_connections| {
455            let initial_count = session_connections.len();
456
457            // Remove closed connections
458            session_connections.retain(|connection_id, sender| {
459                if sender.is_closed() {
460                    debug!("๐Ÿงน Cleaned up closed connection: session={}, connection={}", session_id, connection_id);
461                    false
462                } else {
463                    true
464                }
465            });
466
467            let cleaned_count = initial_count - session_connections.len();
468            total_cleaned += cleaned_count;
469
470            // Keep session if it has active connections
471            !session_connections.is_empty()
472        });
473
474        if total_cleaned > 0 {
475            info!("Cleaned up {} inactive connections", total_cleaned);
476        }
477
478        total_cleaned
479    }
480
481    /// Create SSE stream for POST requests (MCP Streamable HTTP)
482    pub async fn create_post_sse_stream(
483        &self,
484        session_id: String,
485        response: turul_mcp_json_rpc_server::JsonRpcResponse,
486    ) -> Result<hyper::Response<http_body_util::combinators::BoxBody<bytes::Bytes, std::convert::Infallible>>, StreamError> {
487        // Verify session exists
488        if self.storage.get_session(&session_id).await
489            .map_err(|e| StreamError::StorageError(e.to_string()))?
490            .is_none()
491        {
492            return Err(StreamError::SessionNotFound(session_id));
493        }
494
495        info!("Creating POST SSE stream for session: {}", session_id);
496
497        // Create the SSE response body
498        let response_json = serde_json::to_string(&response)
499            .map_err(|e| StreamError::StorageError(format!("Failed to serialize response: {}", e)))?;
500
501        // 1. Include recent notifications that may have been generated during tool execution
502        // Note: Since tool notifications are processed asynchronously, we need to wait a moment
503        // and then check for recent events to include in the POST SSE response
504        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
505
506        let mut sse_frames = Vec::new();
507        let mut event_id_counter = 1;
508        
509        if let Ok(events) = self.storage.get_recent_events(&session_id, 10).await {
510            for event in events {
511                // Convert stored SSE event to notification JSON-RPC format
512                if event.event_type != "ping" { // Skip keepalive events
513                    let notification_sse = format!(
514                        "id: {}\nevent: {}\ndata: {}\n\n",
515                        event_id_counter,
516                        event.event_type, // Use actual event type (e.g., "notifications/message")
517                        event.data
518                    );
519                    debug!("๐Ÿ“ค Including notification in POST SSE stream: id={}, event_type={}", event_id_counter, event.event_type);
520                    sse_frames.push(http_body::Frame::data(Bytes::from(notification_sse)));
521                    event_id_counter += 1;
522                }
523            }
524        }
525
526        // 2. Add the JSON-RPC tool response
527        let response_sse = format!(
528            "id: {}\nevent: result\ndata: {}\n\n", // Tool responses use "result" event type
529            event_id_counter,
530            response_json
531        );
532        debug!("๐Ÿ“ค Sending JSON-RPC response as SSE event: id={}, event=result", event_id_counter);
533        sse_frames.push(http_body::Frame::data(Bytes::from(response_sse)));
534
535        // Create a simple stream from the collected frames
536        let stream = futures::stream::iter(sse_frames.into_iter().map(Ok::<_, std::convert::Infallible>));
537
538        // Create StreamBody from the stream and box it for type erasure
539        let body = StreamBody::new(stream);
540        let boxed_body = http_body_util::combinators::BoxBody::new(body);
541
542        debug!("๐Ÿ“ก POST SSE streaming response created: session={}", session_id);
543
544        // Build response with proper SSE headers including MCP session ID
545        Ok(hyper::Response::builder()
546            .status(hyper::StatusCode::OK)
547            .header(hyper::header::CONTENT_TYPE, "text/event-stream")
548            .header(hyper::header::CACHE_CONTROL, "no-cache")
549            .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, &self.config.cors_origin)
550            .header("Connection", "keep-alive")
551            .header("X-Accel-Buffering", "no") // Prevent proxy buffering
552            .header("Mcp-Session-Id", &session_id)
553            .body(boxed_body)
554            .unwrap())
555    }
556
557    /// Get statistics about active streams
558    pub async fn get_stats(&self) -> StreamStats {
559        let connections = self.connections.read().await;
560        let session_count = self.storage.session_count().await.unwrap_or(0);
561        let event_count = self.storage.event_count().await.unwrap_or(0);
562
563        // Count total active connections
564        let total_connections: usize = connections.values()
565            .map(|session_connections| session_connections.len())
566            .sum();
567
568        StreamStats {
569            active_broadcasters: total_connections, // Now tracks active connections
570            total_sessions: session_count,
571            total_events: event_count,
572            channel_buffer_size: self.config.channel_buffer_size,
573        }
574    }
575}
576
577impl Drop for StreamManager {
578    fn drop(&mut self) {
579        debug!("๐Ÿ”ฅ DROP: StreamManager instance {} - this may cause connection loss!",
580               self.instance_id);
581        debug!("๐Ÿ”ฅ If this appears during request processing, it indicates architecture problem");
582    }
583}
584
/// Point-in-time counters reported by `StreamManager::get_stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamStats {
    /// Number of registered SSE connections across all sessions.
    ///
    /// The field name predates per-connection tracking. The count may include closed
    /// connections that have not been pruned yet.
    pub active_broadcasters: usize,
    /// Number of sessions known to the storage backend (0 if the query failed).
    pub total_sessions: usize,
    /// Number of events held by the storage backend (0 if the query failed).
    pub total_events: usize,
    /// Configured per-connection channel capacity.
    pub channel_buffer_size: usize,
}
593
594// Helper to create async stream
595#[cfg(not(test))]
596use async_stream;
597
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_protocol::ServerCapabilities;
    use turul_mcp_session_storage::{InMemorySessionStorage, SessionStorage};

    /// A new manager over empty in-memory storage has no connections and no sessions.
    #[tokio::test]
    async fn test_stream_manager_creation() {
        let backend = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(backend);

        let snapshot = manager.get_stats().await;
        assert_eq!(snapshot.active_broadcasters, 0);
        assert_eq!(snapshot.total_sessions, 0);
    }

    /// Sending to an existing session persists exactly one event with the returned ID.
    #[tokio::test]
    async fn test_broadcast_to_session() {
        let backend = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(Arc::clone(&backend));

        let session = backend
            .create_session(ServerCapabilities::default())
            .await
            .unwrap();
        let session_id = session.session_id.clone();

        let payload = serde_json::json!({"message": "test"});
        let event_id = manager
            .broadcast_to_session(&session_id, String::from("test"), payload)
            .await
            .unwrap();
        assert!(event_id > 0);

        let stored = backend.get_events_after(&session_id, 0).await.unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].id, event_id);
    }
}