1use std::sync::Arc;
11use std::collections::HashMap;
12use std::pin::Pin;
13use futures::{Stream, StreamExt};
14use hyper::{Response, StatusCode};
15use http_body_util::{StreamBody, BodyExt};
16use bytes::Bytes;
17use hyper::header::{CONTENT_TYPE, CACHE_CONTROL, ACCESS_CONTROL_ALLOW_ORIGIN};
18use serde_json::Value;
19use tokio::sync::{mpsc, RwLock};
20use tracing::{debug, info, error, warn};
21
22use turul_mcp_session_storage::SseEvent;
23
/// Identifier of a single SSE connection within a session; used as the inner key of
/// `StreamManager`'s per-session connection map.
pub type ConnectionId = String;
26
/// Tracks SSE connections per session and routes events to them.
///
/// Events are persisted through `storage` (so they can be replayed when a client reconnects
/// with a last event id) and pushed to connected clients through per-connection `mpsc`
/// channels.
pub struct StreamManager {
    /// Session and event storage backend.
    storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
    /// `session_id -> (connection_id -> sender)` for every registered SSE connection.
    connections: Arc<RwLock<HashMap<String, HashMap<ConnectionId, mpsc::Sender<SseEvent>>>>>,
    /// Stream tuning options (channel capacity, replay cap, keepalive, CORS origin).
    config: StreamConfig,
    /// Per-instance identifier included in several debug log lines.
    instance_id: String,
}
38
/// Tuning options for SSE streams created by `StreamManager`.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Capacity of the per-connection event channel.
    pub channel_buffer_size: usize,
    /// Maximum number of stored events replayed to a reconnecting client.
    pub max_replay_events: usize,
    /// Seconds between keepalive `ping` events on an idle stream.
    pub keepalive_interval_seconds: u64,
    /// Value sent in the `Access-Control-Allow-Origin` response header.
    pub cors_origin: String,
}

impl Default for StreamConfig {
    /// A 1000-slot channel, replay of up to 100 events, a 30 second keepalive and a
    /// wildcard (`*`) CORS origin.
    fn default() -> Self {
        let cors_origin = String::from("*");
        StreamConfig {
            cors_origin,
            keepalive_interval_seconds: 30,
            max_replay_events: 100,
            channel_buffer_size: 1000,
        }
    }
}
62
63pub struct SseStream {
65 stream: Option<Pin<Box<dyn Stream<Item = SseEvent> + Send>>>,
67 session_id: String,
69 connection_id: ConnectionId,
71}
72
73impl SseStream {
74 pub fn session_id(&self) -> &str {
76 &self.session_id
77 }
78
79 pub fn connection_id(&self) -> &str {
81 &self.connection_id
82 }
83
84 pub fn stream_identifier(&self) -> String {
86 format!("{}:{}", self.session_id, self.connection_id)
87 }
88}
89
90impl Drop for SseStream {
91 fn drop(&mut self) {
92 debug!("๐ฅ DROP: SseStream - session={}, connection={}",
93 self.session_id, self.connection_id);
94 if self.stream.is_some() {
95 debug!("๐ฅ Stream still present during drop - this indicates early cleanup");
96 } else {
97 debug!("๐ฅ Stream was properly extracted before drop");
98 }
99 }
100}
101
/// Errors returned by `StreamManager` operations.
#[derive(Debug, thiserror::Error)]
pub enum StreamError {
    /// The session id is unknown to the session storage.
    #[error("Session not found: {0}")]
    SessionNotFound(String),
    /// A specific stream could not be found (session id, stream id).
    #[error("Stream not found: session={0}, stream={1}")]
    StreamNotFound(String, String),
    /// The storage backend reported an error (also used for JSON serialization failures
    /// when building a POST SSE response); carries the error message.
    #[error("Storage error: {0}")]
    StorageError(String),
    /// A connection-level failure; carries a human-readable description.
    #[error("Connection error: {0}")]
    ConnectionError(String),
    /// A broadcast was suppressed because the session has no SSE connection and the caller
    /// asked not to store events in that case.
    #[error("No connections available for session: {0}")]
    NoConnections(String),
}
116
117impl StreamManager {
118 pub fn new(storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>) -> Self {
120 Self::with_config(storage, StreamConfig::default())
121 }
122
123 pub fn with_config(storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>, config: StreamConfig) -> Self {
125 use uuid::Uuid;
126 let instance_id = Uuid::now_v7().to_string();
127 debug!("๐ง Creating StreamManager instance: {}", instance_id);
128 Self {
129 storage,
130 connections: Arc::new(RwLock::new(HashMap::new())),
131 config,
132 instance_id,
133 }
134 }
135
136 pub async fn handle_sse_connection(
138 &self,
139 session_id: String,
140 connection_id: ConnectionId,
141 last_event_id: Option<u64>,
142 ) -> Result<Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>>, StreamError> {
143 if self.storage.get_session(&session_id).await
145 .map_err(|e| StreamError::StorageError(e.to_string()))?
146 .is_none()
147 {
148 return Err(StreamError::SessionNotFound(session_id));
149 }
150
151 let sse_stream = self.create_sse_stream(session_id.clone(), connection_id.clone(), last_event_id).await?;
153
154 let response = self.stream_to_response(sse_stream).await;
156
157 info!("Created SSE connection: session={}, connection={}, last_event_id={:?}",
158 session_id, connection_id, last_event_id);
159
160 Ok(response)
161 }
162
163 async fn create_sse_stream(
165 &self,
166 session_id: String,
167 connection_id: ConnectionId,
168 last_event_id: Option<u64>,
169 ) -> Result<SseStream, StreamError> {
170 let (sender, mut receiver) = mpsc::channel(self.config.channel_buffer_size);
172
173 self.register_connection(&session_id, connection_id.clone(), sender).await;
175
176 let storage = self.storage.clone();
178 let session_id_clone = session_id.clone();
179 let connection_id_clone = connection_id.clone();
180 let config = self.config.clone();
181
182 let combined_stream = async_stream::stream! {
183 if let Some(after_event_id) = last_event_id {
185 debug!("Replaying events after ID {} for session={}, connection={}",
186 after_event_id, session_id_clone, connection_id_clone);
187
188 match storage.get_events_after(&session_id_clone, after_event_id).await {
189 Ok(events) => {
190 for event in events.into_iter().take(config.max_replay_events) {
191 yield event;
192 }
193 },
194 Err(e) => {
195 error!("Failed to get historical events: {}", e);
196 }
198 }
199 }
200
201 let mut keepalive_interval = tokio::time::interval(
203 tokio::time::Duration::from_secs(config.keepalive_interval_seconds)
204 );
205
206 loop {
207 tokio::select! {
208 event = receiver.recv() => {
210 match event {
211 Some(event) => {
212 debug!("๐จ Received event for connection {}: {}", connection_id_clone, event.event_type);
213 yield event;
214 },
215 None => {
216 debug!("Connection channel closed for session={}, connection={}", session_id_clone, connection_id_clone);
217 break;
218 }
219 }
220 },
221
222 _ = keepalive_interval.tick() => {
224 let keepalive_event = SseEvent {
225 id: 0, timestamp: chrono::Utc::now().timestamp_millis() as u64,
227 event_type: "ping".to_string(),
228 data: serde_json::json!({"type": "keepalive"}),
229 retry: None,
230 };
231 yield keepalive_event;
232 }
233 }
234 }
235
236 debug!("๐งน Cleaning up connection: session={}, connection={}", session_id_clone, connection_id_clone);
238 };
239
240 Ok(SseStream {
241 stream: Some(Box::pin(combined_stream)),
242 session_id,
243 connection_id,
244 })
245 }
246
247 async fn register_connection(
249 &self,
250 session_id: &str,
251 connection_id: ConnectionId,
252 sender: mpsc::Sender<SseEvent>
253 ) {
254 let mut connections = self.connections.write().await;
255
256 debug!("[{}] ๐ BEFORE registration: HashMap has {} sessions", self.instance_id, connections.len());
257 for (sid, conns) in connections.iter() {
258 debug!("[{}] ๐ Existing session before: {} with {} connections", self.instance_id, sid, conns.len());
259 }
260
261 let session_connections = connections.entry(session_id.to_string())
263 .or_insert_with(HashMap::new);
264
265 session_connections.insert(connection_id.clone(), sender);
267
268 debug!("[{}] ๐ Registered connection: session={}, connection={}, total_connections={}",
269 self.instance_id, session_id, connection_id, session_connections.len());
270
271 debug!("[{}] ๐ AFTER registration: HashMap has {} sessions", self.instance_id, connections.len());
272 for (sid, conns) in connections.iter() {
273 debug!("[{}] ๐ Session after: {} with {} connections", self.instance_id, sid, conns.len());
274 }
275 }
276
277 pub async fn unregister_connection(&self, session_id: &str, connection_id: &ConnectionId) {
279 debug!("๐ด UNREGISTER called for session={}, connection={}", session_id, connection_id);
280 let mut connections = self.connections.write().await;
281
282 debug!("๐ BEFORE unregister: HashMap has {} sessions", connections.len());
283
284 if let Some(session_connections) = connections.get_mut(session_id) {
285 if session_connections.remove(connection_id).is_some() {
286 debug!("๐ Unregistered connection: session={}, connection={}", session_id, connection_id);
287
288 if session_connections.is_empty() {
290 connections.remove(session_id);
291 debug!("๐งน Removed empty session: {}", session_id);
292 }
293 }
294 }
295
296 debug!("๐ AFTER unregister: HashMap has {} sessions", connections.len());
297 }
298
299 async fn stream_to_response(&self, mut sse_stream: SseStream) -> Response<http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>> {
301 let session_id = sse_stream.session_id().to_string();
303 let stream_identifier = sse_stream.stream_identifier();
304
305 info!("Converting SSE stream to HTTP response: {}", stream_identifier);
307 debug!("Stream details: session_id={}", session_id);
308
309 let stream = sse_stream.stream.take().expect("Stream should be present in SseStream");
312
313 let formatted_stream = stream.map(|event| {
314 let sse_formatted = event.format();
315 debug!("๐ก Streaming SSE event: id={}, event_type={}", event.id, event.event_type);
316 Ok(hyper::body::Frame::data(Bytes::from(sse_formatted)))
317 });
318
319 let body = StreamBody::new(formatted_stream).boxed_unsync();
321
322 Response::builder()
324 .status(StatusCode::OK)
325 .header(CONTENT_TYPE, "text/event-stream")
326 .header(CACHE_CONTROL, "no-cache")
327 .header(ACCESS_CONTROL_ALLOW_ORIGIN, &self.config.cors_origin)
328 .header("Connection", "keep-alive")
329 .body(body)
330 .unwrap()
331 }
332
333 pub async fn has_connections(&self, session_id: &str) -> bool {
335 let connections = self.connections.read().await;
336 connections.get(session_id)
337 .map(|session_connections| !session_connections.is_empty())
338 .unwrap_or(false)
339 }
340
341 pub async fn broadcast_to_session(
343 &self,
344 session_id: &str,
345 event_type: String,
346 data: Value,
347 ) -> Result<u64, StreamError> {
348 self.broadcast_to_session_with_options(session_id, event_type, data, true).await
349 }
350
351 pub async fn broadcast_to_session_with_options(
353 &self,
354 session_id: &str,
355 event_type: String,
356 data: Value,
357 store_when_no_connections: bool,
358 ) -> Result<u64, StreamError> {
359 if !store_when_no_connections && !self.has_connections(session_id).await {
361 debug!("๐ซ Suppressing notification for session {} (no connections, store_when_no_connections=false)", session_id);
362 return Err(StreamError::NoConnections(session_id.to_string()));
363 }
364
365 let event = SseEvent::new(event_type.clone(), data);
367
368 let stored_event = self.storage.store_event(session_id, event).await
370 .map_err(|e| StreamError::StorageError(e.to_string()))?;
371
372 let connections = self.connections.read().await;
374 debug!("[{}] ๐ Checking connections for session {}: connections hashmap has {} sessions",
375 self.instance_id, session_id, connections.len());
376
377 if let Some(session_connections) = connections.get(session_id) {
378 debug!("๐ Session {} found with {} connections", session_id, session_connections.len());
379
380 if !session_connections.is_empty() {
381 let (selected_connection_id, selected_sender) = session_connections.iter().next().unwrap();
383
384 if selected_sender.is_closed() {
386 warn!("๐ Sender is closed for connection: session={}, connection={}",
387 session_id, selected_connection_id);
388 debug!("๐ญ Connection sender was closed, event stored for reconnection");
389 } else {
390 debug!("โ
Sender is open, attempting to send to connection: session={}, connection={}",
391 session_id, selected_connection_id);
392
393 match selected_sender.try_send(stored_event.clone()) {
394 Ok(()) => {
395 info!("โ
Sent notification to ONE connection: session={}, connection={}, event_id={}, method={}",
396 session_id, selected_connection_id, stored_event.id, stored_event.event_type);
397 },
398 Err(mpsc::error::TrySendError::Full(_)) => {
399 warn!("โ ๏ธ Connection buffer full: session={}, connection={}", session_id, selected_connection_id);
400 },
402 Err(mpsc::error::TrySendError::Closed(_)) => {
403 warn!("๐ Connection closed during send: session={}, connection={}", session_id, selected_connection_id);
404 }
406 }
407 }
408 } else {
409 debug!("๐ญ No active connections for session: {} (event stored for reconnection)", session_id);
410 }
411 } else {
412 debug!("๐ญ No connections registered for session: {} (event stored for reconnection)", session_id);
413
414 for (sid, conns) in connections.iter() {
416 debug!("๐ Available session: {} with {} connections", sid, conns.len());
417 }
418 }
419
420 Ok(stored_event.id)
421 }
422
423 pub async fn broadcast_to_all_sessions(
425 &self,
426 event_type: String,
427 data: Value,
428 ) -> Result<Vec<String>, StreamError> {
429 let session_ids = self.storage.list_sessions().await
431 .map_err(|e| StreamError::StorageError(e.to_string()))?;
432
433 let mut failed_sessions = Vec::new();
434
435 for session_id in session_ids {
436 if let Err(e) = self.broadcast_to_session(&session_id, event_type.clone(), data.clone()).await {
437 error!("Failed to broadcast to session {}: {}", session_id, e);
438 failed_sessions.push(session_id);
439 }
440 }
441
442 Ok(failed_sessions)
443 }
444
445 pub async fn cleanup_connections(&self) -> usize {
447 debug!("๐งน CLEANUP_CONNECTIONS called");
448 let mut connections = self.connections.write().await;
449 let mut total_cleaned = 0;
450
451 debug!("๐ BEFORE cleanup: HashMap has {} sessions", connections.len());
452
453 connections.retain(|session_id, session_connections| {
455 let initial_count = session_connections.len();
456
457 session_connections.retain(|connection_id, sender| {
459 if sender.is_closed() {
460 debug!("๐งน Cleaned up closed connection: session={}, connection={}", session_id, connection_id);
461 false
462 } else {
463 true
464 }
465 });
466
467 let cleaned_count = initial_count - session_connections.len();
468 total_cleaned += cleaned_count;
469
470 !session_connections.is_empty()
472 });
473
474 if total_cleaned > 0 {
475 info!("Cleaned up {} inactive connections", total_cleaned);
476 }
477
478 total_cleaned
479 }
480
481 pub async fn create_post_sse_stream(
483 &self,
484 session_id: String,
485 response: turul_mcp_json_rpc_server::JsonRpcResponse,
486 ) -> Result<hyper::Response<http_body_util::combinators::BoxBody<bytes::Bytes, std::convert::Infallible>>, StreamError> {
487 if self.storage.get_session(&session_id).await
489 .map_err(|e| StreamError::StorageError(e.to_string()))?
490 .is_none()
491 {
492 return Err(StreamError::SessionNotFound(session_id));
493 }
494
495 info!("Creating POST SSE stream for session: {}", session_id);
496
497 let response_json = serde_json::to_string(&response)
499 .map_err(|e| StreamError::StorageError(format!("Failed to serialize response: {}", e)))?;
500
501 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
505
506 let mut sse_frames = Vec::new();
507 let mut event_id_counter = 1;
508
509 if let Ok(events) = self.storage.get_recent_events(&session_id, 10).await {
510 for event in events {
511 if event.event_type != "ping" { let notification_sse = format!(
514 "id: {}\nevent: {}\ndata: {}\n\n",
515 event_id_counter,
516 event.event_type, event.data
518 );
519 debug!("๐ค Including notification in POST SSE stream: id={}, event_type={}", event_id_counter, event.event_type);
520 sse_frames.push(http_body::Frame::data(Bytes::from(notification_sse)));
521 event_id_counter += 1;
522 }
523 }
524 }
525
526 let response_sse = format!(
528 "id: {}\nevent: result\ndata: {}\n\n", event_id_counter,
530 response_json
531 );
532 debug!("๐ค Sending JSON-RPC response as SSE event: id={}, event=result", event_id_counter);
533 sse_frames.push(http_body::Frame::data(Bytes::from(response_sse)));
534
535 let stream = futures::stream::iter(sse_frames.into_iter().map(Ok::<_, std::convert::Infallible>));
537
538 let body = StreamBody::new(stream);
540 let boxed_body = http_body_util::combinators::BoxBody::new(body);
541
542 debug!("๐ก POST SSE streaming response created: session={}", session_id);
543
544 Ok(hyper::Response::builder()
546 .status(hyper::StatusCode::OK)
547 .header(hyper::header::CONTENT_TYPE, "text/event-stream")
548 .header(hyper::header::CACHE_CONTROL, "no-cache")
549 .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, &self.config.cors_origin)
550 .header("Connection", "keep-alive")
551 .header("X-Accel-Buffering", "no") .header("Mcp-Session-Id", &session_id)
553 .body(boxed_body)
554 .unwrap())
555 }
556
557 pub async fn get_stats(&self) -> StreamStats {
559 let connections = self.connections.read().await;
560 let session_count = self.storage.session_count().await.unwrap_or(0);
561 let event_count = self.storage.event_count().await.unwrap_or(0);
562
563 let total_connections: usize = connections.values()
565 .map(|session_connections| session_connections.len())
566 .sum();
567
568 StreamStats {
569 active_broadcasters: total_connections, total_sessions: session_count,
571 total_events: event_count,
572 channel_buffer_size: self.config.channel_buffer_size,
573 }
574 }
575}
576
577impl Drop for StreamManager {
578 fn drop(&mut self) {
579 debug!("๐ฅ DROP: StreamManager instance {} - this may cause connection loss!",
580 self.instance_id);
581 debug!("๐ฅ If this appears during request processing, it indicates architecture problem");
582 }
583}
584
/// Snapshot of stream-manager counters.
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Number of connection senders currently registered across all sessions.
    pub active_broadcasters: usize,
    /// Session count reported by storage.
    pub total_sessions: usize,
    /// Event count reported by storage.
    pub total_events: usize,
    /// Configured per-connection channel capacity.
    pub channel_buffer_size: usize,
}
593
594#[cfg(not(test))]
596use async_stream;
597
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_session_storage::{InMemorySessionStorage, SessionStorage};
    use turul_mcp_protocol::ServerCapabilities;

    /// A fresh manager over empty storage reports no connections and no sessions.
    #[tokio::test]
    async fn test_stream_manager_creation() {
        let storage = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(storage);

        let stats = manager.get_stats().await;
        assert_eq!(stats.active_broadcasters, 0);
        assert_eq!(stats.total_sessions, 0);
    }

    /// Broadcasting without an open connection still stores the event and returns its id.
    #[tokio::test]
    async fn test_broadcast_to_session() {
        let storage = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(storage.clone());

        let session = storage.create_session(ServerCapabilities::default()).await.unwrap();
        let session_id = session.session_id.clone();

        let event_id = manager
            .broadcast_to_session(&session_id, "test".to_string(), serde_json::json!({"message": "test"}))
            .await
            .unwrap();

        assert!(event_id > 0);

        let events = storage.get_events_after(&session_id, 0).await.unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id, event_id);
    }

    /// With `store_when_no_connections = false` and no SSE connection, the broadcast is
    /// rejected with `NoConnections` and nothing is written to storage.
    #[tokio::test]
    async fn test_broadcast_suppressed_without_connections() {
        let storage = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(storage.clone());

        let session = storage.create_session(ServerCapabilities::default()).await.unwrap();
        let session_id = session.session_id.clone();

        let result = manager
            .broadcast_to_session_with_options(
                &session_id,
                "test".to_string(),
                serde_json::json!({"message": "test"}),
                false,
            )
            .await;
        assert!(matches!(result, Err(StreamError::NoConnections(_))));

        let events = storage.get_events_after(&session_id, 0).await.unwrap();
        assert!(events.is_empty());
    }

    /// Cleanup on a manager with no registered connections removes nothing.
    #[tokio::test]
    async fn test_cleanup_connections_on_empty_manager() {
        let storage = Arc::new(InMemorySessionStorage::new());
        let manager = StreamManager::new(storage);

        assert_eq!(manager.cleanup_connections().await, 0);
        assert!(!manager.has_connections("missing-session").await);
    }
}
637}