1use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use async_trait::async_trait;
19use serde_json::Value;
20use tokio::sync::{RwLock, broadcast};
21use tracing::{debug, error, info, warn};
22use uuid::Uuid;
23
24use turul_mcp_protocol::{ClientCapabilities, Implementation, McpVersion, ServerCapabilities};
25use turul_mcp_session_storage::{SessionStorage, SessionStorageError, SessionView};
26
27type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
28
29type GetStateFn = Arc<dyn Fn(&str) -> BoxFuture<Option<Value>> + Send + Sync>;
44type SetStateFn = Arc<dyn Fn(&str, Value) -> BoxFuture<()> + Send + Sync>;
45type RemoveStateFn = Arc<dyn Fn(&str) -> BoxFuture<Option<Value>> + Send + Sync>;
46
47#[derive(Clone)]
48pub struct SessionContext {
49    pub session_id: String,
51    pub get_state: GetStateFn,
53    pub set_state: SetStateFn,
55    pub remove_state: RemoveStateFn,
57    pub is_initialized: Arc<dyn Fn() -> BoxFuture<bool> + Send + Sync>,
59    pub send_notification: Arc<dyn Fn(SessionEvent) -> BoxFuture<()> + Send + Sync>,
61    pub broadcaster: Option<Arc<dyn std::any::Any + Send + Sync>>,
63}
64
65impl SessionContext {
66    pub(crate) fn from_json_rpc_with_broadcaster(
68        json_rpc_ctx: turul_mcp_json_rpc_server::SessionContext,
69        storage: Arc<dyn SessionStorage<Error = SessionStorageError>>,
70    ) -> Self {
71        let session_id = json_rpc_ctx.session_id.clone();
72        let broadcaster = json_rpc_ctx.broadcaster.clone();
73
74        let get_state = {
76            let storage = storage.clone();
77            let session_id = session_id.clone();
78            Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
79                let storage = storage.clone();
80                let session_id = session_id.clone();
81                let key = key.to_string();
82                Box::pin(async move {
83                    match storage.get_session_state(&session_id, &key).await {
84                        Ok(Some(value)) => Some(value),
85                        Ok(None) => None,
86                        Err(e) => {
87                            tracing::warn!("Failed to get session state for key '{}': {}", key, e);
88                            None
89                        }
90                    }
91                })
92            })
93        };
94
95        let set_state = {
96            let storage = storage.clone();
97            let session_id = session_id.clone();
98            Arc::new(move |key: &str, value: Value| -> BoxFuture<()> {
99                let storage = storage.clone();
100                let session_id = session_id.clone();
101                let key = key.to_string();
102                Box::pin(async move {
103                    if let Err(e) = storage.set_session_state(&session_id, &key, value).await {
104                        tracing::error!("Failed to set session state for key '{}': {}", key, e);
105                    }
106                })
107            })
108        };
109
110        let remove_state = {
111            let storage = storage.clone();
112            let session_id = session_id.clone();
113            Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
114                let storage = storage.clone();
115                let session_id = session_id.clone();
116                let key = key.to_string();
117                Box::pin(async move {
118                    match storage.remove_session_state(&session_id, &key).await {
119                        Ok(value) => value,
120                        Err(e) => {
121                            tracing::warn!(
122                                "Failed to remove session state for key '{}': {}",
123                                key,
124                                e
125                            );
126                            None
127                        }
128                    }
129                })
130            })
131        };
132
133        let is_initialized = {
134            let storage = storage.clone();
135            let session_id = session_id.clone();
136            Arc::new(move || -> BoxFuture<bool> {
137                let storage = storage.clone();
138                let session_id = session_id.clone();
139                Box::pin(async move {
140                    match storage.get_session(&session_id).await {
141                        Ok(Some(session_info)) => session_info.is_initialized,
142                        Ok(None) => {
143                            tracing::warn!("Session {} not found in storage", session_id);
144                            false
145                        }
146                        Err(e) => {
147                            tracing::error!("Failed to check session initialization: {}", e);
148                            false
149                        }
150                    }
151                })
152            })
153        };
154
155        let send_notification = {
157            let session_id = session_id.clone();
158            let broadcaster = broadcaster.clone();
159            Arc::new(move |event: SessionEvent| -> BoxFuture<()> {
160                let session_id = session_id.clone();
161                let broadcaster = broadcaster.clone();
162                Box::pin(async move {
163                    debug!(
164                        "📨 SessionContext.send_notification() called for session {}: {:?}",
165                        session_id, event
166                    );
167
168                    if let Some(broadcaster_any) = &broadcaster {
170                        debug!(
171                            "✅ NotificationBroadcaster available for session: {}",
172                            session_id
173                        );
174
175                        match event {
177                            SessionEvent::Notification(json_value) => {
178                                debug!(
179                                    "🔧 Attempting to send notification via StreamManagerNotificationBroadcaster"
180                                );
181                                debug!("📦 Notification JSON: {}", json_value);
182
183                                match parse_and_send_notification_with_broadcaster(
185                                    &session_id,
186                                    &json_value,
187                                    broadcaster_any,
188                                )
189                                .await
190                                {
191                                    Ok(_) => debug!(
192                                        "✅ Bridge working: Successfully processed notification for session {}",
193                                        session_id
194                                    ),
195                                    Err(e) => error!(
196                                        "❌ Bridge error: Failed to process notification for session {}: {}",
197                                        session_id, e
198                                    ),
199                                }
200                            }
201                            _ => {
202                                debug!("⚠️ Non-notification event, ignoring: {:?}", event);
203                            }
204                        }
205                    } else {
206                        debug!("⚠️ No broadcaster available for session {}", session_id);
207                    }
208                })
209            })
210        };
211
212        SessionContext {
213            session_id,
214            get_state,
215            set_state,
216            remove_state,
217            is_initialized,
218            send_notification,
219            broadcaster,
220        }
221    }
222
223    pub fn has_broadcaster(&self) -> bool {
225        self.broadcaster.is_some()
226    }
227
228    pub fn get_raw_broadcaster(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
230        self.broadcaster.clone()
231    }
232
233    #[cfg(feature = "test-utils")]
235    pub fn from_json_rpc_with_broadcaster_for_tests(
236        json_rpc_ctx: turul_mcp_json_rpc_server::SessionContext,
237        storage: Arc<dyn SessionStorage<Error = SessionStorageError>>,
238    ) -> Self {
239        Self::from_json_rpc_with_broadcaster(json_rpc_ctx, storage)
240    }
241
242    pub async fn get_typed_state<T>(&self, key: &str) -> Option<T>
244    where
245        T: serde::de::DeserializeOwned,
246    {
247        (self.get_state)(key)
248            .await
249            .and_then(|v| serde_json::from_value(v).ok())
250    }
251
252    pub async fn set_typed_state<T>(&self, key: &str, value: T) -> Result<(), String>
254    where
255        T: serde::Serialize,
256    {
257        match serde_json::to_value(value) {
258            Ok(json_value) => {
259                (self.set_state)(key, json_value).await;
260                Ok(())
261            }
262            Err(e) => Err(format!("Failed to serialize value: {}", e)),
263        }
264    }
265
266    #[cfg(test)]
268    pub fn new_test() -> Self {
269        use std::collections::HashMap;
270        use std::sync::Arc;
271        use tokio::sync::RwLock;
272
273        let state = Arc::new(RwLock::new(HashMap::<String, Value>::new()));
274
275        let get_state = {
276            let state = state.clone();
277            Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
278                let state = state.clone();
279                let key = key.to_string();
280                Box::pin(async move { state.read().await.get(&key).cloned() })
281            })
282        };
283
284        let set_state = {
285            let state = state.clone();
286            Arc::new(move |key: &str, value: Value| -> BoxFuture<()> {
287                let state = state.clone();
288                let key = key.to_string();
289                Box::pin(async move {
290                    state.write().await.insert(key, value);
291                })
292            })
293        };
294
295        let remove_state = {
296            let state = state.clone();
297            Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
298                let state = state.clone();
299                let key = key.to_string();
300                Box::pin(async move { state.write().await.remove(&key) })
301            })
302        };
303
304        let is_initialized = Arc::new(|| -> BoxFuture<bool> { Box::pin(async { true }) });
305
306        let send_notification = Arc::new(|_event: SessionEvent| -> BoxFuture<()> {
307            Box::pin(async {})
308        });
309
310        SessionContext {
311            session_id: Uuid::now_v7().to_string(),
312            get_state,
313            set_state,
314            remove_state,
315            is_initialized,
316            send_notification,
317            broadcaster: None,
318        }
319    }
320
321    pub async fn notify(&self, event: SessionEvent) {
323        debug!(
324            "📨 SessionContext.notify() called for session {}: {:?}",
325            self.session_id, event
326        );
327        (self.send_notification)(event).await;
328        debug!("🚀 SessionContext.notify() send_notification closure completed");
329    }
330
331    pub async fn notify_progress(&self, progress_token: impl Into<String>, progress: u64) {
333        if self.has_broadcaster() {
334            debug!(
335                "🔔 notify_progress using NotificationBroadcaster for session: {}",
336                self.session_id
337            );
338            } else {
340            debug!(
341                "🔔 notify_progress using OLD SessionManager for session: {}",
342                self.session_id
343            );
344        }
345        let mut other = std::collections::HashMap::new();
346        other.insert(
347            "progressToken".to_string(),
348            serde_json::json!(progress_token.into()),
349        );
350        other.insert("progress".to_string(), serde_json::json!(progress));
351
352        let params = turul_mcp_protocol::RequestParams { meta: None, other };
353        let notification =
354            turul_mcp_protocol::JsonRpcNotification::new("notifications/progress".to_string())
355                .with_params(params);
356        self.notify(SessionEvent::Notification(
357            serde_json::to_value(notification).unwrap(),
358        ))
359        .await;
360    }
361
362    pub async fn notify_progress_with_total(
364        &self,
365        progress_token: impl Into<String>,
366        progress: u64,
367        total: u64,
368    ) {
369        let mut other = std::collections::HashMap::new();
370        other.insert(
371            "progressToken".to_string(),
372            serde_json::json!(progress_token.into()),
373        );
374        other.insert("progress".to_string(), serde_json::json!(progress));
375        other.insert("total".to_string(), serde_json::json!(total));
376
377        let params = turul_mcp_protocol::RequestParams { meta: None, other };
378        let notification =
379            turul_mcp_protocol::JsonRpcNotification::new("notifications/progress".to_string())
380                .with_params(params);
381        self.notify(SessionEvent::Notification(
382            serde_json::to_value(notification).unwrap(),
383        ))
384        .await;
385    }
386
387    pub async fn notify_log(
389        &self,
390        level: turul_mcp_protocol::logging::LoggingLevel,
391        data: serde_json::Value,
392        logger: Option<String>,
393        meta: Option<std::collections::HashMap<String, serde_json::Value>>,
394    ) {
395        let message_level = level;
397
398        if !self.should_log(message_level).await {
400            let threshold = self.get_logging_level().await;
401            debug!(
402                "🔕 Filtering out {:?} level message for session {} (threshold: {:?})",
403                message_level, self.session_id, threshold
404            );
405            return;
406        }
407
408        let threshold = self.get_logging_level().await;
409        debug!(
410            "📢 Sending {:?} level message to session {} (threshold: {:?})",
411            message_level, self.session_id, threshold
412        );
413
414        use turul_mcp_protocol::notifications::LoggingMessageNotification;
416        let mut notification = LoggingMessageNotification::new(message_level, data);
417
418        if let Some(logger) = logger {
420            notification = notification.with_logger(logger);
421        }
422
423        if let Some(meta) = meta {
425            notification = notification.with_meta(meta);
426        }
427
428        if self.has_broadcaster() {
429            debug!(
430                "🔔 notify_log using NotificationBroadcaster for session: {}",
431                self.session_id
432            );
433            self.notify(SessionEvent::Notification(
435                serde_json::to_value(notification).unwrap(),
436            ))
437            .await;
438            return;
439        } else {
440            debug!(
441                "🔔 notify_log using OLD SessionManager for session: {}",
442                self.session_id
443            );
444        }
445
446        self.notify(SessionEvent::Notification(
448            serde_json::to_value(notification).unwrap(),
449        ))
450        .await;
451    }
452
453    pub async fn notify_resources_changed(&self) {
455        let notification = turul_mcp_protocol::JsonRpcNotification::new(
456            "notifications/resources/listChanged".to_string(),
457        );
458        self.notify(SessionEvent::Notification(
459            serde_json::to_value(notification).unwrap(),
460        ))
461        .await;
462    }
463
464    pub async fn notify_resource_updated(&self, uri: impl Into<String>) {
466        let mut other = std::collections::HashMap::new();
467        other.insert("uri".to_string(), serde_json::json!(uri.into()));
468
469        let params = turul_mcp_protocol::RequestParams { meta: None, other };
470        let notification = turul_mcp_protocol::JsonRpcNotification::new(
471            "notifications/resources/updated".to_string(),
472        )
473        .with_params(params);
474        self.notify(SessionEvent::Notification(
475            serde_json::to_value(notification).unwrap(),
476        ))
477        .await;
478    }
479
480    pub async fn notify_tools_changed(&self) {
482        let notification = turul_mcp_protocol::JsonRpcNotification::new(
483            "notifications/tools/listChanged".to_string(),
484        );
485        self.notify(SessionEvent::Notification(
486            serde_json::to_value(notification).unwrap(),
487        ))
488        .await;
489    }
490
491    pub async fn get_logging_level(&self) -> turul_mcp_protocol::logging::LoggingLevel {
497        use turul_mcp_protocol::logging::LoggingLevel;
498
499        if let Some(level_value) = (self.get_state)("mcp:logging:level").await {
501            if let Some(level_str) = level_value.as_str() {
502                match level_str {
503                    "debug" => LoggingLevel::Debug,
504                    "info" => LoggingLevel::Info,
505                    "notice" => LoggingLevel::Notice,
506                    "warning" => LoggingLevel::Warning,
507                    "error" => LoggingLevel::Error,
508                    "critical" => LoggingLevel::Critical,
509                    "alert" => LoggingLevel::Alert,
510                    "emergency" => LoggingLevel::Emergency,
511                    _ => LoggingLevel::Info, }
513            } else {
514                LoggingLevel::Info }
516        } else {
517            LoggingLevel::Info }
519    }
520
521    pub async fn set_logging_level(&self, level: turul_mcp_protocol::logging::LoggingLevel) {
523        use turul_mcp_protocol::logging::LoggingLevel;
524
525        let level_str = match level {
526            LoggingLevel::Debug => "debug",
527            LoggingLevel::Info => "info",
528            LoggingLevel::Notice => "notice",
529            LoggingLevel::Warning => "warning",
530            LoggingLevel::Error => "error",
531            LoggingLevel::Critical => "critical",
532            LoggingLevel::Alert => "alert",
533            LoggingLevel::Emergency => "emergency",
534        };
535
536        (self.set_state)("mcp:logging:level", serde_json::json!(level_str)).await;
537        debug!(
538            "🎯 Set logging level for session {}: {:?}",
539            self.session_id, level
540        );
541    }
542
543    pub async fn should_log(
545        &self,
546        message_level: turul_mcp_protocol::logging::LoggingLevel,
547    ) -> bool {
548        let session_threshold = self.get_logging_level().await;
549        message_level.should_log(session_threshold)
550    }
551
552    pub fn should_log_sync(
554        &self,
555        message_level: turul_mcp_protocol::logging::LoggingLevel,
556    ) -> bool {
557        let session_level = futures::executor::block_on(self.get_logging_level());
559        message_level.should_log(session_level)
560    }
561}
562
563#[async_trait]
573impl SessionView for SessionContext {
574    fn session_id(&self) -> &str {
575        &self.session_id
576    }
577
578    async fn get_state(&self, key: &str) -> Result<Option<Value>, String> {
579        Ok((self.get_state)(key).await)
580    }
581
582    async fn set_state(&self, key: &str, value: Value) -> Result<(), String> {
583        (self.set_state)(key, value).await;
584        Ok(())
585    }
586
587    async fn get_metadata(&self, key: &str) -> Result<Option<Value>, String> {
588        let metadata_key = format!("__meta__:{}", key);
590        Ok((self.get_state)(&metadata_key).await)
591    }
592
593    async fn set_metadata(&self, key: &str, value: Value) -> Result<(), String> {
594        let metadata_key = format!("__meta__:{}", key);
596        (self.set_state)(&metadata_key, value).await;
597        Ok(())
598    }
599}
600
601impl turul_mcp_builders::logging::LoggingTarget for SessionContext {
607    fn should_log(&self, level: turul_mcp_protocol::logging::LoggingLevel) -> bool {
608        self.should_log_sync(level)
609    }
610
611    fn notify_log(
612        &self,
613        level: turul_mcp_protocol::logging::LoggingLevel,
614        data: serde_json::Value,
615        logger: Option<String>,
616        meta: Option<std::collections::HashMap<String, serde_json::Value>>,
617    ) {
618        let session_ctx = self.clone();
620        tokio::spawn(async move {
621            session_ctx.notify_log(level, data, logger, meta).await;
622        });
623    }
624}
625
626async fn parse_and_send_notification_with_broadcaster(
628    session_id: &str,
629    json_value: &Value,
630    broadcaster_any: &Arc<dyn std::any::Any + Send + Sync>,
631) -> Result<(), String> {
632    debug!(
633        "🔍 Parsing notification JSON for session {}: {:?}",
634        session_id, json_value
635    );
636
637    use turul_http_mcp_server::notification_bridge::SharedNotificationBroadcaster;
639    use turul_mcp_protocol::notifications::{LoggingMessageNotification, ProgressNotification};
640    debug!(
642        "🔍 Attempting downcast for session {}, broadcaster type: {:?}",
643        session_id,
644        std::any::type_name::<SharedNotificationBroadcaster>()
645    );
646    if let Some(broadcaster) = broadcaster_any.downcast_ref::<SharedNotificationBroadcaster>() {
647        debug!(
648            "✅ Successfully downcast broadcaster for session {}",
649            session_id
650        );
651
652        if let Some(method) = json_value.get("method").and_then(|v| v.as_str()) {
654            match method {
655                "notifications/message" => {
656                    debug!(
657                        "📝 Message notification detected, deserializing directly to LoggingMessageNotification"
658                    );
659
660                    match serde_json::from_value::<LoggingMessageNotification>(json_value.clone()) {
662                        Ok(notification) => {
663                            debug!(
664                                "✅ Successfully deserialized LoggingMessageNotification: level={:?}, logger={:?}",
665                                notification.params.level, notification.params.logger
666                            );
667
668                            debug!(
669                                "🔧 About to call broadcaster.send_message_notification() for session {}",
670                                session_id
671                            );
672                            match broadcaster
674                                .send_message_notification(session_id, notification)
675                                .await
676                            {
677                                Ok(()) => {
678                                    debug!(
679                                        "🎉 SUCCESS: LoggingMessageNotification sent to StreamManager for session {}",
680                                        session_id
681                                    );
682                                    debug!(
683                                        "🚀 Streamable HTTP Transport Bridge: Complete end-to-end delivery confirmed!"
684                                    );
685                                    return Ok(());
686                                }
687                                Err(e) => {
688                                    error!(
689                                        "❌ Failed to send LoggingMessageNotification to StreamManager: {}",
690                                        e
691                                    );
692                                    return Err(format!(
693                                        "Failed to send LoggingMessageNotification: {}",
694                                        e
695                                    ));
696                                }
697                            }
698                        }
699                        Err(e) => {
700                            error!("❌ Failed to deserialize LoggingMessageNotification: {}", e);
701                            return Err(format!(
702                                "Failed to deserialize LoggingMessageNotification: {}",
703                                e
704                            ));
705                        }
706                    }
707                }
708                "notifications/progress" => {
709                    if let Some(params) = json_value.get("params")
710                        && let Some(token) = params.get("progressToken").and_then(|v| v.as_str())
711                    {
712                        debug!("📊 Progress notification detected: token={}", token);
713
714                        let progress = params.get("progress").and_then(|v| v.as_u64()).unwrap_or(0);
716
717                        let notification = ProgressNotification {
719                            method: "notifications/progress".to_string(),
720                            params: turul_mcp_protocol::notifications::ProgressNotificationParams {
721                                progress_token: token.to_string(),
722                                progress,
723                                total: params.get("total").and_then(|v| v.as_u64()),
724                                message: params
725                                    .get("message")
726                                    .and_then(|v| v.as_str())
727                                    .map(|s| s.to_string()),
728                                meta: None,
729                            },
730                        };
731
732                        debug!(
733                            "🔧 About to call broadcaster.send_progress_notification() for session {}",
734                            session_id
735                        );
736                        match broadcaster
738                            .send_progress_notification(session_id, notification)
739                            .await
740                        {
741                            Ok(()) => {
742                                debug!(
743                                    "🎉 SUCCESS: ProgressNotification sent to StreamManager for session {}",
744                                    session_id
745                                );
746                                debug!(
747                                    "🚀 Streamable HTTP Transport Bridge: Complete end-to-end delivery confirmed!"
748                                );
749                                return Ok(());
750                            }
751                            Err(e) => {
752                                error!(
753                                    "❌ Failed to send ProgressNotification to StreamManager: {}",
754                                    e
755                                );
756                                return Err(format!("Failed to send ProgressNotification: {}", e));
757                            }
758                        }
759                    }
760                }
761                _ => {
762                    debug!(
763                        "🔧 Other notification method: {} - sending as generic JsonRpcNotification",
764                        method
765                    );
766
767                    let params_map: std::collections::HashMap<String, serde_json::Value> =
769                        json_value
770                            .get("params")
771                            .and_then(|p| p.as_object())
772                            .unwrap_or(&serde_json::Map::new())
773                            .iter()
774                            .map(|(k, v)| (k.clone(), v.clone()))
775                            .collect();
776                    let json_rpc_notification =
777                        turul_mcp_json_rpc_server::JsonRpcNotification::new_with_object_params(
778                            method.to_string(),
779                            params_map,
780                        );
781
782                    match broadcaster
783                        .send_notification(session_id, json_rpc_notification)
784                        .await
785                    {
786                        Ok(()) => {
787                            debug!(
788                                "🎉 SUCCESS: Generic notification sent to StreamManager for session {}",
789                                session_id
790                            );
791                            return Ok(());
792                        }
793                        Err(e) => {
794                            error!(
795                                "❌ Failed to send generic notification to StreamManager: {}",
796                                e
797                            );
798                            return Err(format!("Failed to send generic notification: {}", e));
799                        }
800                    }
801                }
802            }
803        }
804    } else {
805        error!(
806            "❌ Failed to downcast broadcaster for session {}",
807            session_id
808        );
809        return Err("Failed to downcast broadcaster to SharedNotificationBroadcaster".to_string());
810    }
811
812    debug!(
813        "❓ Could not determine notification type for session {}",
814        session_id
815    );
816    Ok(())
817}
818
819#[derive(Debug, Clone)]
821pub enum SessionEvent {
822    Notification(Value),
824    KeepAlive,
826    Disconnect,
828    Custom { event_type: String, data: Value },
830}
831
832#[derive(Debug)]
834pub struct McpSession {
835    pub id: String,
837    pub created: Instant,
839    pub last_accessed: Instant,
841    pub mcp_version: McpVersion,
843    pub client_capabilities: Option<ClientCapabilities>,
845    pub server_capabilities: ServerCapabilities,
847    pub client_info: Option<Implementation>,
849    pub state: HashMap<String, Value>,
851    pub event_sender: broadcast::Sender<SessionEvent>,
853    pub initialized: bool,
855}
856
857impl McpSession {
858    pub fn new(server_capabilities: ServerCapabilities) -> Self {
860        let session_id = Uuid::now_v7().to_string();
861        let (event_sender, _) = broadcast::channel(128);
862
863        Self {
864            id: session_id,
865            created: Instant::now(),
866            last_accessed: Instant::now(),
867            mcp_version: McpVersion::CURRENT,
868            client_capabilities: None,
869            server_capabilities,
870            client_info: None,
871            state: HashMap::new(),
872            event_sender,
873            initialized: false,
874        }
875    }
876
877    pub fn touch(&mut self) {
879        self.last_accessed = Instant::now();
880    }
881
882    pub fn is_expired(&self, timeout: Duration) -> bool {
884        self.last_accessed.elapsed() > timeout
885    }
886
887    pub fn initialize(
889        &mut self,
890        client_info: Implementation,
891        client_capabilities: ClientCapabilities,
892    ) {
893        self.client_info = Some(client_info);
894        self.client_capabilities = Some(client_capabilities);
895        self.initialized = true;
896        self.touch();
897    }
898
899    pub fn initialize_with_version(
901        &mut self,
902        client_info: Implementation,
903        client_capabilities: ClientCapabilities,
904        mcp_version: McpVersion,
905    ) {
906        self.client_info = Some(client_info);
907        self.client_capabilities = Some(client_capabilities);
908        self.mcp_version = mcp_version;
909        self.initialized = true;
910        self.touch();
911    }
912
913    pub fn get_state(&self, key: &str) -> Option<Value> {
915        self.state.get(key).cloned()
916    }
917
918    pub fn set_state(&mut self, key: &str, value: Value) {
920        self.state.insert(key.to_string(), value);
921        self.touch();
922    }
923
924    pub fn remove_state(&mut self, key: &str) -> Option<Value> {
926        let result = self.state.remove(key);
927        if result.is_some() {
928            self.touch();
929        }
930        result
931    }
932
933    pub fn send_event(&self, event: SessionEvent) -> Result<(), String> {
935        self.event_sender
936            .send(event)
937            .map_err(|e| format!("Failed to send event: {}", e))?;
938        Ok(())
939    }
940
941    pub fn subscribe_events(&self) -> broadcast::Receiver<SessionEvent> {
943        self.event_sender.subscribe()
944    }
945}
946
947#[derive(Debug, thiserror::Error)]
949pub enum SessionError {
950    #[error("Session not found: {0}")]
951    NotFound(String),
952    #[error("Session expired: {0}")]
953    Expired(String),
954    #[error("Session not initialized: {0}")]
955    NotInitialized(String),
956    #[error("Invalid session data: {0}")]
957    InvalidData(String),
958    #[error("Storage error: {0}")]
959    StorageError(String),
960}
961
962pub struct SessionManager {
964    storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
966    sessions: RwLock<HashMap<String, McpSession>>,
968    session_timeout: Duration,
970    cleanup_interval: Duration,
972    default_capabilities: ServerCapabilities,
974    global_event_sender: broadcast::Sender<(String, SessionEvent)>,
976}
977
978impl SessionManager {
979    pub fn new(default_capabilities: ServerCapabilities) -> Self {
981        let storage: Arc<turul_mcp_session_storage::BoxedSessionStorage> =
982            Arc::new(turul_mcp_session_storage::InMemorySessionStorage::new());
983        Self::with_storage_and_timeouts(
984            storage,
985            default_capabilities,
986            Duration::from_secs(30 * 60), Duration::from_secs(60),      )
989    }
990
991    pub fn with_timeouts(
993        default_capabilities: ServerCapabilities,
994        session_timeout: Duration,
995        cleanup_interval: Duration,
996    ) -> Self {
997        let storage: Arc<turul_mcp_session_storage::BoxedSessionStorage> =
998            Arc::new(turul_mcp_session_storage::InMemorySessionStorage::new());
999        Self::with_storage_and_timeouts(
1000            storage,
1001            default_capabilities,
1002            session_timeout,
1003            cleanup_interval,
1004        )
1005    }
1006
1007    pub fn with_storage(
1009        storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
1010        default_capabilities: ServerCapabilities,
1011    ) -> Self {
1012        Self::with_storage_and_timeouts(
1013            storage,
1014            default_capabilities,
1015            Duration::from_secs(30 * 60), Duration::from_secs(60),      )
1018    }
1019
1020    pub fn with_storage_and_timeouts(
1022        storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
1023        default_capabilities: ServerCapabilities,
1024        session_timeout: Duration,
1025        cleanup_interval: Duration,
1026    ) -> Self {
1027        let (global_event_sender, _) = broadcast::channel(1000);
1028
1029        Self {
1030            storage,
1031            sessions: RwLock::new(HashMap::new()),
1032            session_timeout,
1033            cleanup_interval,
1034            default_capabilities,
1035            global_event_sender,
1036        }
1037    }
1038
1039    pub async fn create_session(&self) -> String {
1041        let session = McpSession::new(self.default_capabilities.clone());
1042        let session_id = session.id.clone();
1043
1044        debug!("Creating new session: {}", session_id);
1045
1046        match self
1048            .storage
1049            .create_session_with_id(session_id.clone(), self.default_capabilities.clone())
1050            .await
1051        {
1052            Ok(_) => debug!("Session {} created in storage backend", session_id),
1053            Err(e) => error!("Failed to create session {} in storage: {}", session_id, e),
1054        }
1055
1056        self.sessions
1058            .write()
1059            .await
1060            .insert(session_id.clone(), session);
1061        session_id
1062    }
1063
1064    pub async fn create_session_with_id(&self, session_id: String) -> String {
1066        let mut session = McpSession::new(self.default_capabilities.clone());
1067        session.id = session_id.clone();
1068
1069        debug!("Creating session with provided ID: {}", session_id);
1070
1071        match self
1073            .storage
1074            .create_session_with_id(session_id.clone(), self.default_capabilities.clone())
1075            .await
1076        {
1077            Ok(_) => debug!("Session {} created in storage backend", session_id),
1078            Err(e) => error!("Failed to create session {} in storage: {}", session_id, e),
1079        }
1080
1081        self.sessions
1083            .write()
1084            .await
1085            .insert(session_id.clone(), session);
1086        session_id
1087    }
1088
1089    pub async fn add_session_to_cache(
1092        &self,
1093        session_id: String,
1094        server_capabilities: ServerCapabilities,
1095    ) {
1096        let mut session = McpSession::new(server_capabilities);
1097        session.id = session_id.clone(); debug!("Adding externally created session {} to cache", session_id);
1100        self.sessions.write().await.insert(session_id, session);
1101    }
1102
1103    pub async fn load_session_from_storage(&self, session_id: &str) -> Result<bool, SessionError> {
1106        match self.storage.get_session(session_id).await {
1107            Ok(Some(session_info)) => {
1108                debug!("Loading session {} from storage", session_id);
1109
1110                let server_capabilities =
1112                    session_info.server_capabilities.clone().unwrap_or_else(|| {
1113                        warn!(
1114                            "Session {} in storage has no server capabilities, using defaults",
1115                            session_id
1116                        );
1117                        self.default_capabilities.clone()
1118                    });
1119
1120                let mut session = McpSession::new(server_capabilities);
1121                session.id = session_id.to_string();
1122                session.initialized = session_info.is_initialized;
1123                session.client_capabilities = session_info.client_capabilities.clone();
1124                session.state = session_info.state.clone();
1125
1126                let now = std::time::SystemTime::now()
1129                    .duration_since(std::time::UNIX_EPOCH)
1130                    .unwrap_or_default()
1131                    .as_millis() as u64;
1132
1133                let created_elapsed = if now > session_info.created_at {
1134                    Duration::from_millis(now - session_info.created_at)
1135                } else {
1136                    Duration::from_secs(0)
1137                };
1138
1139                let last_activity_elapsed = if now > session_info.last_activity {
1140                    Duration::from_millis(now - session_info.last_activity)
1141                } else {
1142                    Duration::from_secs(0)
1143                };
1144
1145                session.created = Instant::now() - created_elapsed;
1147                session.last_accessed = Instant::now() - last_activity_elapsed;
1148
1149                self.sessions
1151                    .write()
1152                    .await
1153                    .insert(session_id.to_string(), session);
1154
1155                debug!(
1156                    "Session {} loaded from storage: initialized={}, has_capabilities={}",
1157                    session_id,
1158                    session_info.is_initialized,
1159                    session_info.server_capabilities.is_some()
1160                );
1161
1162                Ok(true)
1163            }
1164            Ok(None) => {
1165                debug!("Session {} not found in storage", session_id);
1166                Ok(false)
1167            }
1168            Err(e) => {
1169                error!("Failed to get session {} from storage: {}", session_id, e);
1170                Err(SessionError::StorageError(e.to_string()))
1171            }
1172        }
1173    }
1174
1175    pub async fn touch_session(&self, session_id: &str) -> Result<(), SessionError> {
1177        let mut sessions = self.sessions.write().await;
1178        if let Some(session) = sessions.get_mut(session_id) {
1179            if session.is_expired(self.session_timeout) {
1180                sessions.remove(session_id);
1181                return Err(SessionError::Expired(session_id.to_string()));
1182            }
1183            session.touch();
1184            Ok(())
1185        } else {
1186            Err(SessionError::NotFound(session_id.to_string()))
1187        }
1188    }
1189
1190    pub async fn initialize_session(
1192        &self,
1193        session_id: &str,
1194        client_info: Implementation,
1195        client_capabilities: ClientCapabilities,
1196    ) -> Result<(), SessionError> {
1197        if let Ok(Some(mut session_info)) = self.storage.get_session(session_id).await {
1199            session_info.client_capabilities = Some(client_capabilities.clone());
1200            session_info.is_initialized = true;
1201            session_info.touch();
1202
1203            if let Err(e) = self.storage.update_session(session_info).await {
1204                error!("Failed to update session in storage: {}", e);
1205            }
1206        }
1207
1208        let mut sessions = self.sessions.write().await;
1210        if let Some(session) = sessions.get_mut(session_id) {
1211            session.initialize(client_info, client_capabilities);
1212            debug!("Session {} initialized", session_id);
1213            Ok(())
1214        } else {
1215            Err(SessionError::NotFound(session_id.to_string()))
1216        }
1217    }
1218
1219    pub async fn initialize_session_with_version(
1221        &self,
1222        session_id: &str,
1223        client_info: Implementation,
1224        client_capabilities: ClientCapabilities,
1225        mcp_version: McpVersion,
1226    ) -> Result<(), SessionError> {
1227        if let Ok(Some(mut session_info)) = self.storage.get_session(session_id).await {
1229            session_info.client_capabilities = Some(client_capabilities.clone());
1230            session_info.is_initialized = true;
1231            session_info.touch();
1232            if let Err(e) = self.storage.update_session(session_info).await {
1235                error!(
1236                    "❌ CRITICAL: Failed to update session {} in storage: {}",
1237                    session_id, e
1238                );
1239                return Err(SessionError::StorageError(format!(
1240                    "Failed to persist session initialization: {}",
1241                    e
1242                )));
1243            }
1244            debug!(
1245                "✅ Session {} storage updated with is_initialized=true",
1246                session_id
1247            );
1248        } else {
1249            error!(
1250                "❌ Session {} not found in storage during initialization",
1251                session_id
1252            );
1253            return Err(SessionError::NotFound(session_id.to_string()));
1254        }
1255
1256        let mut sessions = self.sessions.write().await;
1258        if let Some(session) = sessions.get_mut(session_id) {
1259            session.initialize_with_version(client_info, client_capabilities, mcp_version);
1260            debug!(
1261                "✅ Session {} cache updated with protocol version {}",
1262                session_id, mcp_version
1263            );
1264            Ok(())
1265        } else {
1266            warn!(
1267                "⚠️ Session {} not found in cache but exists in storage - creating cache entry",
1268                session_id
1269            );
1270            Ok(())
1273        }
1274    }
1275
1276    pub async fn session_exists(&self, session_id: &str) -> bool {
1278        match self.storage.get_session(session_id).await {
1280            Ok(Some(session_info)) => {
1281                let timeout_minutes = self.session_timeout.as_secs() / 60;
1283                !session_info.is_expired(timeout_minutes)
1284            }
1285            Ok(None) => false,
1286            Err(e) => {
1287                debug!("Storage backend error for session_exists: {}", e);
1288                let sessions = self.sessions.read().await;
1290                sessions
1291                    .get(session_id)
1292                    .map(|s| !s.is_expired(self.session_timeout))
1293                    .unwrap_or(false)
1294            }
1295        }
1296    }
1297
1298    pub async fn get_session_state(&self, session_id: &str, key: &str) -> Option<Value> {
1300        match self.storage.get_session_state(session_id, key).await {
1302            Ok(value) => value,
1303            Err(e) => {
1304                debug!("Storage backend error for get_session_state: {}", e);
1305                let sessions = self.sessions.read().await;
1307                sessions.get(session_id)?.get_state(key)
1308            }
1309        }
1310    }
1311
1312    pub async fn set_session_state(&self, session_id: &str, key: &str, value: Value) {
1314        if let Err(e) = self
1316            .storage
1317            .set_session_state(session_id, key, value.clone())
1318            .await
1319        {
1320            error!("Failed to set session state in storage: {}", e);
1321        }
1322
1323        let mut sessions = self.sessions.write().await;
1325        if let Some(session) = sessions.get_mut(session_id) {
1326            session.set_state(key, value);
1327        }
1328    }
1329
1330    pub async fn remove_session_state(&self, session_id: &str, key: &str) -> Option<Value> {
1332        let storage_result = match self.storage.remove_session_state(session_id, key).await {
1334            Ok(value) => value,
1335            Err(e) => {
1336                error!("Failed to remove session state from storage: {}", e);
1337                None
1338            }
1339        };
1340
1341        let mut sessions = self.sessions.write().await;
1343        let memory_result = sessions.get_mut(session_id)?.remove_state(key);
1344
1345        storage_result.or(memory_result)
1347    }
1348
1349    pub async fn is_session_initialized(&self, session_id: &str) -> bool {
1351        match self.storage.get_session(session_id).await {
1353            Ok(Some(session_info)) => {
1354                debug!(
1355                    "✅ Session {} initialization status from storage: {}",
1356                    session_id, session_info.is_initialized
1357                );
1358                session_info.is_initialized
1359            }
1360            Ok(None) => {
1361                debug!("⚠️ Session {} not found in storage", session_id);
1362                false
1363            }
1364            Err(e) => {
1365                warn!(
1366                    "⚠️ Failed to check session {} in storage: {} - falling back to cache",
1367                    session_id, e
1368                );
1369                let sessions = self.sessions.read().await;
1371                sessions
1372                    .get(session_id)
1373                    .map(|s| s.initialized)
1374                    .unwrap_or(false)
1375            }
1376        }
1377    }
1378
1379    pub async fn remove_session(&self, session_id: &str) -> bool {
1381        let storage_removed = match self.storage.delete_session(session_id).await {
1383            Ok(removed) => {
1384                if removed {
1385                    debug!("Session {} removed from storage backend", session_id);
1386                }
1387                removed
1388            }
1389            Err(e) => {
1390                error!(
1391                    "Failed to remove session {} from storage: {}",
1392                    session_id, e
1393                );
1394                false
1395            }
1396        };
1397
1398        let mut sessions = self.sessions.write().await;
1400        let memory_removed = if let Some(session) = sessions.remove(session_id) {
1401            debug!("Session {} removed from memory cache", session_id);
1402            let _ = session.send_event(SessionEvent::Disconnect);
1404            true
1405        } else {
1406            false
1407        };
1408
1409        storage_removed || memory_removed
1411    }
1412
1413    pub async fn cleanup_expired(&self) -> usize {
1415        let timeout_duration = self.session_timeout;
1416        let cutoff = std::time::SystemTime::now() - timeout_duration;
1417
1418        let storage_removed = match self.storage.expire_sessions(cutoff).await {
1420            Ok(expired_ids) => {
1421                let count = expired_ids.len();
1422                if count > 0 {
1423                    info!(
1424                        "Storage backend cleaned up {} expired sessions: {:?}",
1425                        count, expired_ids
1426                    );
1427                }
1428                count
1429            }
1430            Err(e) => {
1431                error!("Failed to clean up expired sessions from storage: {}", e);
1432                0
1433            }
1434        };
1435
1436        let cutoff_instant = Instant::now() - timeout_duration;
1438        let mut sessions = self.sessions.write().await;
1439        let initial_count = sessions.len();
1440
1441        sessions.retain(|id, session| {
1442            let keep = session.last_accessed >= cutoff_instant;
1443            if !keep {
1444                info!("Session {} expired and removed from memory cache", id);
1445                let _ = session.send_event(SessionEvent::Disconnect);
1447            }
1448            keep
1449        });
1450
1451        let memory_removed = initial_count - sessions.len();
1452
1453        std::cmp::max(storage_removed, memory_removed)
1455    }
1456
1457    pub async fn send_event_to_session(
1459        &self,
1460        session_id: &str,
1461        event: SessionEvent,
1462    ) -> Result<(), SessionError> {
1463        let sessions = self.sessions.read().await;
1464        if let Some(session) = sessions.get(session_id) {
1465            session
1467                .send_event(event.clone())
1468                .map_err(SessionError::InvalidData)?;
1469
1470            debug!(
1472                "🌐 Forwarding event to global broadcaster: session={}, event={:?}",
1473                session_id, event
1474            );
1475            if let Err(e) = self
1476                .global_event_sender
1477                .send((session_id.to_string(), event))
1478            {
1479                debug!("⚠️ Global event broadcast failed (no listeners): {}", e);
1480            } else {
1481                debug!("✅ Global event broadcast succeeded");
1482            }
1483
1484            Ok(())
1485        } else {
1486            Err(SessionError::NotFound(session_id.to_string()))
1487        }
1488    }
1489
1490    pub async fn broadcast_event(&self, event: SessionEvent) {
1492        let sessions = self.sessions.read().await;
1493        for (session_id, session) in sessions.iter() {
1494            if let Err(e) = session.send_event(event.clone()) {
1495                warn!("Failed to send event to session {}: {}", session_id, e);
1496            }
1497        }
1498    }
1499
1500    pub async fn session_count(&self) -> usize {
1502        match self.storage.session_count().await {
1504            Ok(count) => count,
1505            Err(e) => {
1506                debug!("Storage backend error for session_count: {}", e);
1507                self.sessions.read().await.len()
1509            }
1510        }
1511    }
1512
1513    pub fn create_session_context(self: &Arc<Self>, session_id: &str) -> Option<SessionContext> {
1515        let session_id = session_id.to_string();
1516        let session_manager = Arc::clone(self);
1517
1518        let get_state = {
1520            let session_manager = session_manager.clone();
1521            let session_id = session_id.clone();
1522            Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
1523                let session_manager = session_manager.clone();
1524                let session_id = session_id.clone();
1525                let key = key.to_string();
1526                Box::pin(async move { session_manager.get_session_state(&session_id, &key).await })
1527            })
1528        };
1529
1530        let set_state = {
1531            let session_manager = session_manager.clone();
1532            let session_id = session_id.clone();
1533            Arc::new(move |key: &str, value: Value| -> BoxFuture<()> {
1534                let session_manager = session_manager.clone();
1535                let session_id = session_id.clone();
1536                let key = key.to_string();
1537                Box::pin(async move {
1538                    let _ = session_manager
1539                        .set_session_state(&session_id, &key, value)
1540                        .await;
1541                })
1542            })
1543        };
1544
1545        let remove_state = {
1546            let session_manager = session_manager.clone();
1547            let session_id = session_id.clone();
1548            Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
1549                let session_manager = session_manager.clone();
1550                let session_id = session_id.clone();
1551                let key = key.to_string();
1552                Box::pin(async move {
1553                    session_manager
1554                        .remove_session_state(&session_id, &key)
1555                        .await
1556                })
1557            })
1558        };
1559
1560        let is_initialized = {
1561            let session_manager = session_manager.clone();
1562            let session_id = session_id.clone();
1563            Arc::new(move || -> BoxFuture<bool> {
1564                let session_manager = session_manager.clone();
1565                let session_id = session_id.clone();
1566                Box::pin(async move { session_manager.is_session_initialized(&session_id).await })
1567            })
1568        };
1569
1570        let send_notification = {
1571            let session_manager = session_manager.clone();
1572            let session_id = session_id.clone();
1573            Arc::new(move |event: SessionEvent| -> BoxFuture<()> {
1574                let session_manager = session_manager.clone();
1575                let session_id = session_id.clone();
1576                Box::pin(async move {
1577                    let _ = session_manager
1578                        .send_event_to_session(&session_id, event)
1579                        .await;
1580                })
1581            })
1582        };
1583
1584        Some(SessionContext {
1585            session_id,
1586            get_state,
1587            set_state,
1588            remove_state,
1589            is_initialized,
1590            send_notification,
1591            broadcaster: None, })
1593    }
1594
1595    pub fn start_cleanup_task(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
1597        let manager = Arc::clone(&self);
1598        tokio::spawn(async move {
1599            let mut interval = tokio::time::interval(manager.cleanup_interval);
1600            loop {
1601                interval.tick().await;
1602                let cleaned = manager.cleanup_expired().await;
1603                if cleaned > 0 {
1604                    debug!("Cleaned up {} expired sessions", cleaned);
1605                }
1606            }
1607        })
1608    }
1609
1610    pub async fn get_session_event_receiver(
1612        &self,
1613        session_id: &str,
1614    ) -> Option<broadcast::Receiver<SessionEvent>> {
1615        let sessions = self.sessions.read().await;
1616        Some(sessions.get(session_id)?.subscribe_events())
1617    }
1618
1619    pub fn subscribe_all_session_events(&self) -> broadcast::Receiver<(String, SessionEvent)> {
1622        self.global_event_sender.subscribe()
1623    }
1624
1625    pub fn get_storage(&self) -> Arc<turul_mcp_session_storage::BoxedSessionStorage> {
1628        Arc::clone(&self.storage)
1629    }
1630
1631    pub fn get_default_capabilities(&self) -> ServerCapabilities {
1633        self.default_capabilities.clone()
1634    }
1635
1636    pub async fn session_exists_in_cache(&self, session_id: &str) -> bool {
1638        self.sessions.read().await.contains_key(session_id)
1639    }
1640}
1641
1642#[async_trait]
1644pub trait SessionAware {
1645    async fn handle_with_session(
1647        &self,
1648        params: Option<Value>,
1649        session: Option<SessionContext>,
1650    ) -> Result<Value, String>;
1651}
1652
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a manager with default capabilities and in-memory storage.
    fn default_manager() -> SessionManager {
        SessionManager::new(ServerCapabilities::default())
    }

    #[tokio::test]
    async fn test_session_creation() {
        let manager = default_manager();

        let id = manager.create_session().await;

        assert!(!id.is_empty());
        assert!(manager.session_exists(&id).await);
    }

    #[tokio::test]
    async fn test_session_state() {
        let manager = default_manager();
        let id = manager.create_session().await;

        manager
            .set_session_state(&id, "test_key", json!("test_value"))
            .await;
        assert_eq!(
            manager.get_session_state(&id, "test_key").await,
            Some(json!("test_value"))
        );

        assert_eq!(
            manager.remove_session_state(&id, "test_key").await,
            Some(json!("test_value"))
        );
        assert_eq!(manager.get_session_state(&id, "test_key").await, None);
    }

    #[tokio::test]
    async fn test_session_context() {
        let manager = Arc::new(default_manager());
        let id = manager.create_session().await;
        let ctx = manager
            .create_session_context(&id)
            .expect("create_session_context always returns Some");

        (ctx.set_state)("test", json!("value")).await;
        assert_eq!((ctx.get_state)("test").await, Some(json!("value")));
        assert_eq!((ctx.remove_state)("test").await, Some(json!("value")));

        // Smoke-test the notification helpers; no listener is attached.
        ctx.notify_log(
            turul_mcp_protocol::logging::LoggingLevel::Info,
            json!("Test notification"),
            Some("test".to_string()),
            None,
        )
        .await;
        ctx.notify_progress("test-token", 50).await;
    }

    #[tokio::test]
    async fn test_session_expiry() {
        // Same configuration as `new()` except for a 100 ms session timeout.
        let manager = SessionManager::with_timeouts(
            ServerCapabilities::default(),
            Duration::from_millis(100),
            Duration::from_secs(60),
        );
        let id = manager.create_session().await;
        assert!(manager.session_exists_in_cache(&id).await);

        // Sleep past the timeout so the cached entry is stale.
        tokio::time::sleep(Duration::from_millis(150)).await;

        let touched = manager.touch_session(&id).await;
        assert!(matches!(touched, Err(SessionError::Expired(_))));
    }
}
1737}