turul_http_mcp_server/
stream_manager.rs

//! Enhanced Stream Manager with MCP 2025-06-18 resumability.
//!
//! This module provides SSE stream management with:
//! - Event IDs for resumability
//! - `Last-Event-ID` header support
//! - Per-session event targeting (events are not broadcast to all sessions)
//! - Event persistence and replay
//! - Proper HTTP status codes and headers
9
10use bytes::Bytes;
11use futures::{Stream, StreamExt};
12use http_body_util::{BodyExt, StreamBody};
13use hyper::header::{ACCESS_CONTROL_ALLOW_ORIGIN, CACHE_CONTROL, CONTENT_TYPE};
14use hyper::{Response, StatusCode};
15use serde_json::Value;
16use std::collections::{HashMap, HashSet};
17use std::pin::Pin;
18use std::sync::Arc;
19use tokio::sync::{RwLock, mpsc};
20use tracing::{debug, error, warn};
21
22use turul_mcp_session_storage::SseEvent;
23
24/// Connection ID for tracking individual SSE streams
25pub type ConnectionId = String;
26pub type SessionConnections = HashMap<ConnectionId, mpsc::Sender<SseEvent>>;
27pub type ConnectionsMap = Arc<RwLock<HashMap<String, SessionConnections>>>;
28
29/// Enhanced stream manager with resumability support (MCP spec compliant)
30pub struct StreamManager {
31    /// Session storage backend for persistence
32    storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
33    /// Per-session connections for real-time events (MCP compliant - no broadcasting)
34    connections: ConnectionsMap,
35    /// Per-session notification subscriptions (what notifications each session wants)
36    subscriptions: Arc<RwLock<HashMap<String, HashSet<String>>>>,
37    /// Configuration
38    config: StreamConfig,
39    /// Unique instance ID for debugging
40    instance_id: String,
41}
42
/// Configuration for stream management.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamConfig {
    /// Capacity of the per-connection `mpsc` channel that carries real-time events.
    pub channel_buffer_size: usize,
    /// Maximum number of stored events replayed when a client reconnects with
    /// a `Last-Event-ID`.
    pub max_replay_events: usize,
    /// Interval between keep-alive `ping` events, in seconds.
    pub keepalive_interval_seconds: u64,
    /// Value sent in the `Access-Control-Allow-Origin` response header.
    pub cors_origin: String,
}

impl Default for StreamConfig {
    /// Defaults: a 1000-event channel buffer, replay of up to 100 events,
    /// a 30 second keep-alive interval, and a permissive `*` CORS origin.
    fn default() -> Self {
        Self {
            channel_buffer_size: 1000,
            max_replay_events: 100,
            keepalive_interval_seconds: 30,
            cors_origin: "*".to_string(),
        }
    }
}
66
67/// SSE stream wrapper that formats events properly (MCP compliant - one connection per stream)
68pub struct SseStream {
69    /// Underlying event stream
70    stream: Option<Pin<Box<dyn Stream<Item = SseEvent> + Send>>>,
71    /// Session metadata
72    session_id: String,
73    /// Connection identifier (for MCP spec compliance)
74    connection_id: ConnectionId,
75}
76
77impl SseStream {
78    /// Get the session ID this stream belongs to
79    pub fn session_id(&self) -> &str {
80        &self.session_id
81    }
82
83    /// Get the connection ID for this stream
84    pub fn connection_id(&self) -> &str {
85        &self.connection_id
86    }
87
88    /// Get stream identifier for logging (session + connection)
89    pub fn stream_identifier(&self) -> String {
90        format!("{}:{}", self.session_id, self.connection_id)
91    }
92}
93
94impl Drop for SseStream {
95    fn drop(&mut self) {
96        debug!(
97            "DROP: SseStream - session={}, connection={}",
98            self.session_id, self.connection_id
99        );
100        if self.stream.is_some() {
101            debug!("Stream still present during drop - this indicates early cleanup");
102        } else {
103            debug!("Stream was properly extracted before drop");
104        }
105    }
106}
107
108/// Error type for stream management
109#[derive(Debug, thiserror::Error)]
110pub enum StreamError {
111    #[error("Session not found: {0}")]
112    SessionNotFound(String),
113    #[error("Stream not found: session={0}, stream={1}")]
114    StreamNotFound(String, String),
115    #[error("Storage error: {0}")]
116    StorageError(String),
117    #[error("Connection error: {0}")]
118    ConnectionError(String),
119    #[error("No connections available for session: {0}")]
120    NoConnections(String),
121    #[error("Session {0} not subscribed to notification type: {1}")]
122    NotSubscribed(String, String),
123}
124
125impl StreamManager {
126    /// Create new stream manager with session storage backend
127    pub fn new(storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>) -> Self {
128        Self::with_config(storage, StreamConfig::default())
129    }
130
131    /// Create stream manager with custom configuration
132    pub fn with_config(
133        storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
134        config: StreamConfig,
135    ) -> Self {
136        use uuid::Uuid;
137        let instance_id = Uuid::now_v7().to_string();
138        debug!("Creating StreamManager instance: {}", instance_id);
139        Self {
140            storage,
141            connections: Arc::new(RwLock::new(HashMap::new())),
142            subscriptions: Arc::new(RwLock::new(HashMap::new())),
143            config,
144            instance_id,
145        }
146    }
147
148    /// Handle SSE connection request with proper resumability
149    pub async fn handle_sse_connection(
150        &self,
151        session_id: String,
152        connection_id: ConnectionId,
153        last_event_id: Option<u64>,
154    ) -> Result<
155        Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>,
156        StreamError,
157    > {
158        // Verify session exists
159        if self
160            .storage
161            .get_session(&session_id)
162            .await
163            .map_err(|e| StreamError::StorageError(e.to_string()))?
164            .is_none()
165        {
166            return Err(StreamError::SessionNotFound(session_id));
167        }
168
169        // Create the SSE stream (one per connection, MCP compliant)
170        let sse_stream = self
171            .create_sse_stream(session_id.clone(), connection_id.clone(), last_event_id)
172            .await?;
173
174        // Convert to HTTP response
175        let response = self.stream_to_response(sse_stream).await;
176
177        debug!(
178            "Created SSE connection: session={}, connection={}, last_event_id={:?}",
179            session_id, connection_id, last_event_id
180        );
181
182        Ok(response)
183    }
184
185    /// Create SSE stream with resumability support (MCP compliant - no broadcast)
186    async fn create_sse_stream(
187        &self,
188        session_id: String,
189        connection_id: ConnectionId,
190        last_event_id: Option<u64>,
191    ) -> Result<SseStream, StreamError> {
192        // Create mpsc channel for this specific connection (MCP compliant)
193        let (sender, mut receiver) = mpsc::channel(self.config.channel_buffer_size);
194
195        // Register this connection with the session
196        self.register_connection(&session_id, connection_id.clone(), sender)
197            .await;
198
199        // Create the combined stream
200        let storage = self.storage.clone();
201        let session_id_clone = session_id.clone();
202        let connection_id_clone = connection_id.clone();
203        let config = self.config.clone();
204
205        let combined_stream = async_stream::stream! {
206            // 1. First, yield any historical events (resumability)
207            if let Some(after_event_id) = last_event_id {
208                debug!("Replaying events after ID {} for session={}, connection={}",
209                       after_event_id, session_id_clone, connection_id_clone);
210
211                match storage.get_events_after(&session_id_clone, after_event_id).await {
212                    Ok(events) => {
213                        for event in events.into_iter().take(config.max_replay_events) {
214                            yield event;
215                        }
216                    },
217                    Err(e) => {
218                        error!("Failed to get historical events: {}", e);
219                        // Continue with real-time events even if historical replay fails
220                    }
221                }
222            }
223
224            // 2. Then, stream real-time events from dedicated channel
225            let mut keepalive_interval = tokio::time::interval(
226                tokio::time::Duration::from_secs(config.keepalive_interval_seconds)
227            );
228
229            loop {
230                tokio::select! {
231                    // Real-time events from this connection's channel
232                    event = receiver.recv() => {
233                        match event {
234                            Some(event) => {
235                                debug!("Received event for connection {}: {}", connection_id_clone, event.event_type);
236                                yield event;
237                            },
238                            None => {
239                                debug!("Connection channel closed for session={}, connection={}", session_id_clone, connection_id_clone);
240                                break;
241                            }
242                        }
243                    },
244
245                    // Keep-alive pings
246                    _ = keepalive_interval.tick() => {
247                        let keepalive_event = SseEvent {
248                            id: 0, // Keep-alive events don't need persistent IDs
249                            timestamp: chrono::Utc::now().timestamp_millis() as u64,
250                            event_type: "ping".to_string(),
251                            data: serde_json::json!({"type": "keepalive"}),
252                            retry: None,
253                        };
254                        yield keepalive_event;
255                    }
256                }
257            }
258
259            // Clean up connection when stream ends
260            debug!("Cleaning up connection: session={}, connection={}", session_id_clone, connection_id_clone);
261        };
262
263        Ok(SseStream {
264            stream: Some(Box::pin(combined_stream)),
265            session_id,
266            connection_id,
267        })
268    }
269
270    /// Register a new connection for a session (MCP compliant)
271    async fn register_connection(
272        &self,
273        session_id: &str,
274        connection_id: ConnectionId,
275        sender: mpsc::Sender<SseEvent>,
276    ) {
277        let mut connections = self.connections.write().await;
278
279        debug!(
280            "[{}] ๐Ÿ” BEFORE registration: HashMap has {} sessions",
281            self.instance_id,
282            connections.len()
283        );
284        for (sid, conns) in connections.iter() {
285            debug!(
286                "[{}] ๐Ÿ” Existing session before: {} with {} connections",
287                self.instance_id,
288                sid,
289                conns.len()
290            );
291        }
292
293        // Get or create session entry
294        let session_connections = connections
295            .entry(session_id.to_string())
296            .or_insert_with(HashMap::new);
297
298        // Add this connection
299        session_connections.insert(connection_id.clone(), sender);
300
301        debug!(
302            "[{}] ๐Ÿ”— Registered connection: session={}, connection={}, total_connections={}",
303            self.instance_id,
304            session_id,
305            connection_id,
306            session_connections.len()
307        );
308
309        debug!(
310            "[{}] ๐Ÿ” AFTER registration: HashMap has {} sessions",
311            self.instance_id,
312            connections.len()
313        );
314        for (sid, conns) in connections.iter() {
315            debug!(
316                "[{}] ๐Ÿ” Session after: {} with {} connections",
317                self.instance_id,
318                sid,
319                conns.len()
320            );
321        }
322    }
323
324    /// Register a streaming connection to receive events for a session (public API for POST streaming)
325    pub async fn register_streaming_connection(
326        &self,
327        session_id: &str,
328        connection_id: ConnectionId,
329        sender: mpsc::Sender<SseEvent>,
330    ) -> Result<(), StreamError> {
331        // Verify session exists first
332        if self
333            .storage
334            .get_session(session_id)
335            .await
336            .map_err(|e| StreamError::StorageError(e.to_string()))?
337            .is_none()
338        {
339            return Err(StreamError::SessionNotFound(session_id.to_string()));
340        }
341
342        self.register_connection(session_id, connection_id, sender)
343            .await;
344        Ok(())
345    }
346
347    /// Remove a connection when it's closed
348    pub async fn unregister_connection(&self, session_id: &str, connection_id: &ConnectionId) {
349        debug!(
350            "๐Ÿ”ด UNREGISTER called for session={}, connection={}",
351            session_id, connection_id
352        );
353        let mut connections = self.connections.write().await;
354
355        debug!(
356            "๐Ÿ” BEFORE unregister: HashMap has {} sessions",
357            connections.len()
358        );
359
360        if let Some(session_connections) = connections.get_mut(session_id)
361            && session_connections.remove(connection_id).is_some()
362        {
363            debug!(
364                "๐Ÿ”Œ Unregistered connection: session={}, connection={}",
365                session_id, connection_id
366            );
367
368            // Clean up empty sessions
369            if session_connections.is_empty() {
370                connections.remove(session_id);
371                debug!("๐Ÿงน Removed empty session: {}", session_id);
372            }
373        }
374
375        debug!(
376            "๐Ÿ” AFTER unregister: HashMap has {} sessions",
377            connections.len()
378        );
379    }
380
381    /// Close all SSE connections for a session (useful for session termination)
382    pub async fn close_session_connections(&self, session_id: &str) -> usize {
383        debug!("๐Ÿ”ด Closing all connections for session: {}", session_id);
384        let mut connections = self.connections.write().await;
385
386        let closed_count = if let Some(session_connections) = connections.remove(session_id) {
387            let count = session_connections.len();
388            debug!(
389                "๐Ÿ”Œ Closed {} SSE connections for session: {}",
390                count, session_id
391            );
392            count
393        } else {
394            debug!("๐Ÿ” No SSE connections found for session: {}", session_id);
395            0
396        };
397
398        // Also clear subscriptions for this session
399        self.clear_subscriptions(session_id).await;
400
401        debug!("๐Ÿงน Session {} removed from stream manager", session_id);
402        closed_count
403    }
404
405    /// Convert SSE stream to HTTP response with proper headers
406    async fn stream_to_response(
407        &self,
408        mut sse_stream: SseStream,
409    ) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>> {
410        // Extract session info before moving the stream
411        let session_id = sse_stream.session_id().to_string();
412        let stream_identifier = sse_stream.stream_identifier();
413
414        // Log stream creation with session identifier
415        debug!(
416            "Converting SSE stream to HTTP response: {}",
417            stream_identifier
418        );
419        debug!("Stream details: session_id={}", session_id);
420
421        // Transform events to SSE format and create proper HTTP frames
422        // Extract stream from Option wrapper
423        let stream = sse_stream
424            .stream
425            .take()
426            .expect("Stream should be present in SseStream");
427
428        let formatted_stream = stream.map(|event| {
429            let sse_formatted = event.format();
430            debug!(
431                "๐Ÿ“ก Streaming SSE event: id={}, event_type={}",
432                event.id, event.event_type
433            );
434            Ok(hyper::body::Frame::data(Bytes::from(sse_formatted)))
435        });
436
437        // Create streaming body from the actual event stream and box it
438        let body = StreamBody::new(formatted_stream).boxed_unsync();
439
440        // Build response with proper SSE headers for streaming
441        Response::builder()
442            .status(StatusCode::OK)
443            .header(CONTENT_TYPE, "text/event-stream")
444            .header(CACHE_CONTROL, "no-cache")
445            .header(ACCESS_CONTROL_ALLOW_ORIGIN, &self.config.cors_origin)
446            .header("Connection", "keep-alive")
447            .body(body)
448            .unwrap()
449    }
450
451    /// Check if a session has any active SSE connections
452    pub async fn has_connections(&self, session_id: &str) -> bool {
453        let connections = self.connections.read().await;
454        connections
455            .get(session_id)
456            .map(|session_connections| !session_connections.is_empty())
457            .unwrap_or(false)
458    }
459
460    /// Send event to specific session (MCP compliant - ONE connection only)
461    pub async fn broadcast_to_session(
462        &self,
463        session_id: &str,
464        event_type: String,
465        data: Value,
466    ) -> Result<u64, StreamError> {
467        self.broadcast_to_session_with_options(session_id, event_type, data, true)
468            .await
469    }
470
471    /// Send event to specific session with option to suppress when no connections exist
472    pub async fn broadcast_to_session_with_options(
473        &self,
474        session_id: &str,
475        event_type: String,
476        data: Value,
477        store_when_no_connections: bool,
478    ) -> Result<u64, StreamError> {
479        // Check subscription filtering first
480        if !self.is_subscribed(session_id, &event_type).await {
481            debug!(
482                "๐Ÿšซ Session {} not subscribed to notification type: {}",
483                session_id, event_type
484            );
485            return Err(StreamError::NotSubscribed(
486                session_id.to_string(),
487                event_type,
488            ));
489        }
490
491        // Check if we should suppress notifications when no connections exist
492        if !store_when_no_connections && !self.has_connections(session_id).await {
493            debug!(
494                "๐Ÿšซ Suppressing notification for session {} (no connections, store_when_no_connections=false)",
495                session_id
496            );
497            return Err(StreamError::NoConnections(session_id.to_string()));
498        }
499
500        // Create the event
501        let event = SseEvent::new(event_type.clone(), data);
502
503        // Store event for resumability (always store for compliant clients)
504        let stored_event = self
505            .storage
506            .store_event(session_id, event)
507            .await
508            .map_err(|e| StreamError::StorageError(e.to_string()))?;
509
510        // DEBUG: Check connection state more thoroughly
511        let connections = self.connections.read().await;
512        debug!(
513            "[{}] ๐Ÿ” Checking connections for session {}: connections hashmap has {} sessions",
514            self.instance_id,
515            session_id,
516            connections.len()
517        );
518
519        if let Some(session_connections) = connections.get(session_id) {
520            debug!(
521                "๐Ÿ” Session {} found with {} connections",
522                session_id,
523                session_connections.len()
524            );
525
526            if !session_connections.is_empty() {
527                // Pick the FIRST available connection (MCP compliant)
528                let (selected_connection_id, selected_sender) =
529                    session_connections.iter().next().unwrap();
530
531                // Check if sender is closed
532                if selected_sender.is_closed() {
533                    warn!(
534                        "๐Ÿ”Œ Sender is closed for connection: session={}, connection={}",
535                        session_id, selected_connection_id
536                    );
537                    debug!("๐Ÿ“ญ Connection sender was closed, event stored for reconnection");
538                } else {
539                    debug!(
540                        "โœ… Sender is open, attempting to send to connection: session={}, connection={}",
541                        session_id, selected_connection_id
542                    );
543
544                    match selected_sender.try_send(stored_event.clone()) {
545                        Ok(()) => {
546                            debug!(
547                                "Sent notification to ONE connection: session={}, connection={}, event_id={}, method={}",
548                                session_id,
549                                selected_connection_id,
550                                stored_event.id,
551                                stored_event.event_type
552                            );
553                        }
554                        Err(mpsc::error::TrySendError::Full(_)) => {
555                            warn!(
556                                "โš ๏ธ Connection buffer full: session={}, connection={}",
557                                session_id, selected_connection_id
558                            );
559                            // Event is still stored for reconnection
560                        }
561                        Err(mpsc::error::TrySendError::Closed(_)) => {
562                            warn!(
563                                "๐Ÿ”Œ Connection closed during send: session={}, connection={}",
564                                session_id, selected_connection_id
565                            );
566                            // Event is still stored for reconnection
567                        }
568                    }
569                }
570            } else {
571                debug!(
572                    "๐Ÿ“ญ No active connections for session: {} (event stored for reconnection)",
573                    session_id
574                );
575            }
576        } else {
577            debug!(
578                "๐Ÿ“ญ No connections registered for session: {} (event stored for reconnection)",
579                session_id
580            );
581
582            // DEBUG: List all sessions in connections
583            for (sid, conns) in connections.iter() {
584                debug!(
585                    "๐Ÿ” Available session: {} with {} connections",
586                    sid,
587                    conns.len()
588                );
589            }
590        }
591
592        Ok(stored_event.id)
593    }
594
595    /// Broadcast to all sessions (for server-wide notifications)
596    pub async fn broadcast_to_all_sessions(
597        &self,
598        event_type: String,
599        data: Value,
600    ) -> Result<Vec<String>, StreamError> {
601        // Get all session IDs
602        let session_ids = self
603            .storage
604            .list_sessions()
605            .await
606            .map_err(|e| StreamError::StorageError(e.to_string()))?;
607
608        let mut failed_sessions = Vec::new();
609
610        for session_id in session_ids {
611            if let Err(e) = self
612                .broadcast_to_session(&session_id, event_type.clone(), data.clone())
613                .await
614            {
615                error!("Failed to broadcast to session {}: {}", session_id, e);
616                failed_sessions.push(session_id);
617            }
618        }
619
620        Ok(failed_sessions)
621    }
622
623    /// Clean up closed connections
624    pub async fn cleanup_connections(&self) -> usize {
625        debug!("๐Ÿงน CLEANUP_CONNECTIONS called");
626        let mut connections = self.connections.write().await;
627        let mut total_cleaned = 0;
628
629        debug!(
630            "๐Ÿ” BEFORE cleanup: HashMap has {} sessions",
631            connections.len()
632        );
633
634        // Clean up closed connections
635        connections.retain(|session_id, session_connections| {
636            let initial_count = session_connections.len();
637
638            // Remove closed connections
639            session_connections.retain(|connection_id, sender| {
640                if sender.is_closed() {
641                    debug!(
642                        "๐Ÿงน Cleaned up closed connection: session={}, connection={}",
643                        session_id, connection_id
644                    );
645                    false
646                } else {
647                    true
648                }
649            });
650
651            let cleaned_count = initial_count - session_connections.len();
652            total_cleaned += cleaned_count;
653
654            // Keep session if it has active connections
655            !session_connections.is_empty()
656        });
657
658        if total_cleaned > 0 {
659            debug!("Cleaned up {} inactive connections", total_cleaned);
660        }
661
662        total_cleaned
663    }
664
665    /// Create SSE stream for POST requests (MCP Streamable HTTP)
666    pub async fn create_post_sse_stream(
667        &self,
668        session_id: String,
669        response: turul_mcp_json_rpc_server::JsonRpcResponse,
670    ) -> Result<
671        hyper::Response<
672            http_body_util::combinators::BoxBody<bytes::Bytes, std::convert::Infallible>,
673        >,
674        StreamError,
675    > {
676        // Verify session exists
677        if self
678            .storage
679            .get_session(&session_id)
680            .await
681            .map_err(|e| StreamError::StorageError(e.to_string()))?
682            .is_none()
683        {
684            return Err(StreamError::SessionNotFound(session_id));
685        }
686
687        debug!("Creating POST SSE stream for session: {}", session_id);
688
689        // Create the SSE response body
690        let response_json = serde_json::to_string(&response).map_err(|e| {
691            StreamError::StorageError(format!("Failed to serialize response: {}", e))
692        })?;
693
694        // 1. Include recent notifications that may have been generated during tool execution
695        // Note: Since tool notifications are processed asynchronously, we need to wait a moment
696        // and then check for recent events to include in the POST SSE response
697        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
698
699        let mut sse_frames = Vec::new();
700        let mut event_id_counter = 1;
701
702        if let Ok(events) = self.storage.get_recent_events(&session_id, 10).await {
703            for event in events {
704                // Convert stored SSE event to notification JSON-RPC format
705                if event.event_type != "ping" {
706                    // Skip keepalive events
707                    let notification_sse = format!(
708                        "id: {}\nevent: {}\ndata: {}\n\n",
709                        event_id_counter,
710                        event.event_type, // Use actual event type (e.g., "notifications/message")
711                        event.data
712                    );
713                    debug!(
714                        "๐Ÿ“ค Including notification in POST SSE stream: id={}, event_type={}",
715                        event_id_counter, event.event_type
716                    );
717                    sse_frames.push(http_body::Frame::data(Bytes::from(notification_sse)));
718                    event_id_counter += 1;
719                }
720            }
721        }
722
723        // 2. Add the JSON-RPC tool response
724        let response_sse = format!(
725            "id: {}\nevent: result\ndata: {}\n\n", // Tool responses use "result" event type
726            event_id_counter, response_json
727        );
728        debug!(
729            "๐Ÿ“ค Sending JSON-RPC response as SSE event: id={}, event=result",
730            event_id_counter
731        );
732        sse_frames.push(http_body::Frame::data(Bytes::from(response_sse)));
733
734        // Create a simple stream from the collected frames
735        let stream = futures::stream::iter(
736            sse_frames
737                .into_iter()
738                .map(Ok::<_, std::convert::Infallible>),
739        );
740
741        // Create StreamBody from the stream and box it for type erasure
742        let body = StreamBody::new(stream);
743        let boxed_body = http_body_util::combinators::BoxBody::new(body);
744
745        debug!(
746            "๐Ÿ“ก POST SSE streaming response created: session={}",
747            session_id
748        );
749
750        // Build response with proper SSE headers including MCP session ID
751        Ok(hyper::Response::builder()
752            .status(hyper::StatusCode::OK)
753            .header(hyper::header::CONTENT_TYPE, "text/event-stream")
754            .header(hyper::header::CACHE_CONTROL, "no-cache")
755            .header(
756                hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN,
757                &self.config.cors_origin,
758            )
759            .header("Connection", "keep-alive")
760            .header("X-Accel-Buffering", "no") // Prevent proxy buffering
761            .header("Mcp-Session-Id", &session_id)
762            .body(boxed_body)
763            .unwrap())
764    }
765
766    /// Subscribe a session to specific notification types
767    pub async fn subscribe_to_notifications(
768        &self,
769        session_id: &str,
770        notification_types: Vec<String>,
771    ) {
772        let mut subscriptions = self.subscriptions.write().await;
773        let session_subscriptions = subscriptions
774            .entry(session_id.to_string())
775            .or_insert_with(HashSet::new);
776
777        for notification_type in notification_types {
778            session_subscriptions.insert(notification_type.clone());
779            debug!(
780                "๐Ÿ“ Session {} subscribed to notification: {}",
781                session_id, notification_type
782            );
783        }
784
785        debug!(
786            "Session {} now has {} subscriptions",
787            session_id,
788            session_subscriptions.len()
789        );
790    }
791
792    /// Unsubscribe a session from specific notification types
793    pub async fn unsubscribe_from_notifications(
794        &self,
795        session_id: &str,
796        notification_types: Vec<String>,
797    ) {
798        let mut subscriptions = self.subscriptions.write().await;
799        if let Some(session_subscriptions) = subscriptions.get_mut(session_id) {
800            for notification_type in notification_types {
801                if session_subscriptions.remove(&notification_type) {
802                    debug!(
803                        "๐Ÿ“ Session {} unsubscribed from notification: {}",
804                        session_id, notification_type
805                    );
806                }
807            }
808
809            // Remove session entry if no subscriptions remain
810            if session_subscriptions.is_empty() {
811                subscriptions.remove(session_id);
812                debug!(
813                    "๐Ÿ—‘๏ธ Removed subscription entry for session {} (no remaining subscriptions)",
814                    session_id
815                );
816            }
817        }
818    }
819
820    /// Check if a session is subscribed to a specific notification type
821    pub async fn is_subscribed(&self, session_id: &str, notification_type: &str) -> bool {
822        let subscriptions = self.subscriptions.read().await;
823        subscriptions
824            .get(session_id)
825            .map(|session_subscriptions| session_subscriptions.contains(notification_type))
826            .unwrap_or(true) // Default: allow all notifications if no explicit subscriptions
827    }
828
829    /// Get all subscriptions for a session
830    pub async fn get_subscriptions(&self, session_id: &str) -> HashSet<String> {
831        let subscriptions = self.subscriptions.read().await;
832        subscriptions.get(session_id).cloned().unwrap_or_default()
833    }
834
835    /// Clear all subscriptions for a session (used during session cleanup)
836    pub async fn clear_subscriptions(&self, session_id: &str) {
837        let mut subscriptions = self.subscriptions.write().await;
838        if subscriptions.remove(session_id).is_some() {
839            debug!("๐Ÿ—‘๏ธ Cleared all subscriptions for session: {}", session_id);
840        }
841    }
842
843    /// Get the stream configuration (for testing and debugging)
844    pub fn get_config(&self) -> &StreamConfig {
845        &self.config
846    }
847
848    /// Get statistics about active streams
849    pub async fn get_stats(&self) -> StreamStats {
850        let connections = self.connections.read().await;
851        let session_count = self.storage.session_count().await.unwrap_or(0);
852        let event_count = self.storage.event_count().await.unwrap_or(0);
853
854        // Count total active connections
855        let total_connections: usize = connections
856            .values()
857            .map(|session_connections| session_connections.len())
858            .sum();
859
860        StreamStats {
861            active_broadcasters: total_connections, // Now tracks active connections
862            total_sessions: session_count,
863            total_events: event_count,
864            channel_buffer_size: self.config.channel_buffer_size,
865        }
866    }
867}
868
869impl Drop for StreamManager {
870    fn drop(&mut self) {
871        debug!(
872            "DROP: StreamManager instance {} - this may cause connection loss!",
873            self.instance_id
874        );
875        debug!("If this appears during request processing, it indicates architecture problem");
876    }
877}
878
/// Stream manager statistics.
///
/// Derives `PartialEq`/`Eq`/`Default` so snapshots can be compared directly
/// (e.g. with `assert_eq!`) and a zeroed value can be constructed.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct StreamStats {
    /// Number of active SSE connections across all sessions (the historical field
    /// name is kept for compatibility).
    pub active_broadcasters: usize,
    /// Number of sessions reported by the session storage backend.
    pub total_sessions: usize,
    /// Number of events reported by the session storage backend.
    pub total_events: usize,
    /// Configured channel buffer size for real-time broadcasting.
    pub channel_buffer_size: usize,
}
887
888// Helper to create async stream
889#[cfg(not(test))]
890use async_stream;
891
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_protocol::ServerCapabilities;
    use turul_mcp_session_storage::{InMemorySessionStorage, SessionStorage};

    /// A freshly created manager reports no connections, sessions, or events.
    #[tokio::test]
    async fn test_stream_manager_creation() {
        let storage = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(storage);

        let stats = manager.get_stats().await;
        assert_eq!(stats.active_broadcasters, 0);
        assert_eq!(stats.total_sessions, 0);
        assert_eq!(stats.total_events, 0);
    }

    /// Broadcasting to a session persists the event with the returned id.
    #[tokio::test]
    async fn test_broadcast_to_session() {
        let storage = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(storage.clone());

        // Create a session
        let session = storage
            .create_session(ServerCapabilities::default())
            .await
            .unwrap();
        let session_id = session.session_id.clone();

        // Broadcast an event
        let event_id = manager
            .broadcast_to_session(
                &session_id,
                "test".to_string(),
                serde_json::json!({"message": "test"}),
            )
            .await
            .unwrap();

        assert!(event_id > 0);

        // Verify event was stored
        let events = storage.get_events_after(&session_id, 0).await.unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id, event_id);
    }

    /// Exercises the subscribe / query / partial-unsubscribe / clear lifecycle.
    #[tokio::test]
    async fn test_notification_subscriptions() {
        let storage = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(storage);

        // No explicit subscriptions: every notification type is allowed.
        assert!(manager.is_subscribed("s1", "notifications/progress").await);
        assert!(manager.get_subscriptions("s1").await.is_empty());

        manager
            .subscribe_to_notifications(
                "s1",
                vec![
                    "notifications/progress".to_string(),
                    "notifications/message".to_string(),
                ],
            )
            .await;
        assert!(manager.is_subscribed("s1", "notifications/progress").await);
        assert!(manager.is_subscribed("s1", "notifications/message").await);
        assert!(!manager.is_subscribed("s1", "notifications/other").await);
        assert_eq!(manager.get_subscriptions("s1").await.len(), 2);

        // Partial unsubscribe keeps the remaining filter in place.
        manager
            .unsubscribe_from_notifications("s1", vec!["notifications/message".to_string()])
            .await;
        assert!(!manager.is_subscribed("s1", "notifications/message").await);
        assert!(manager.is_subscribed("s1", "notifications/progress").await);
        assert_eq!(manager.get_subscriptions("s1").await.len(), 1);

        // Other sessions are unaffected by s1's subscriptions.
        assert!(manager.is_subscribed("s2", "notifications/other").await);

        // Clearing restores the default allow-all behaviour.
        manager.clear_subscriptions("s1").await;
        assert!(manager.is_subscribed("s1", "notifications/other").await);
        assert!(manager.get_subscriptions("s1").await.is_empty());
    }
}