1use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use async_trait::async_trait;
19use serde_json::Value;
20use tokio::sync::{RwLock, broadcast};
21use tracing::{debug, error, info, warn};
22use uuid::Uuid;
23
24use turul_mcp_protocol::{ClientCapabilities, Implementation, McpVersion, ServerCapabilities};
25use turul_mcp_session_storage::{SessionStorage, SessionStorageError, SessionView};
26
27type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
28
29type GetStateFn = Arc<dyn Fn(&str) -> BoxFuture<Option<Value>> + Send + Sync>;
44type SetStateFn = Arc<dyn Fn(&str, Value) -> BoxFuture<()> + Send + Sync>;
45type RemoveStateFn = Arc<dyn Fn(&str) -> BoxFuture<Option<Value>> + Send + Sync>;
46
/// Per-session context that exposes session state and notification delivery as shared
/// async callbacks.
///
/// The callbacks let the same type be backed either by a session storage implementation
/// (`from_json_rpc_with_broadcaster`) or by an in-memory map (`new_test`).
#[derive(Clone)]
pub struct SessionContext {
    /// Identifier of the session this context belongs to.
    pub session_id: String,
    /// Reads a session-state value by key.
    pub get_state: GetStateFn,
    /// Writes a session-state value by key.
    pub set_state: SetStateFn,
    /// Removes a session-state key, yielding the removed value if any.
    pub remove_state: RemoveStateFn,
    /// Reports whether the session has completed MCP initialization.
    pub is_initialized: Arc<dyn Fn() -> BoxFuture<bool> + Send + Sync>,
    /// Delivers a `SessionEvent` for this session.
    pub send_notification: Arc<dyn Fn(SessionEvent) -> BoxFuture<()> + Send + Sync>,
    /// Type-erased notification broadcaster, when the transport supplied one; the
    /// notification path downcasts it to `SharedNotificationBroadcaster`.
    pub broadcaster: Option<Arc<dyn std::any::Any + Send + Sync>>,
}
64
65impl SessionContext {
66 pub(crate) fn from_json_rpc_with_broadcaster(
68 json_rpc_ctx: turul_mcp_json_rpc_server::SessionContext,
69 storage: Arc<dyn SessionStorage<Error = SessionStorageError>>,
70 ) -> Self {
71 let session_id = json_rpc_ctx.session_id.clone();
72 let broadcaster = json_rpc_ctx.broadcaster.clone();
73
74 let get_state = {
76 let storage = storage.clone();
77 let session_id = session_id.clone();
78 Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
79 let storage = storage.clone();
80 let session_id = session_id.clone();
81 let key = key.to_string();
82 Box::pin(async move {
83 match storage.get_session_state(&session_id, &key).await {
84 Ok(Some(value)) => Some(value),
85 Ok(None) => None,
86 Err(e) => {
87 tracing::warn!("Failed to get session state for key '{}': {}", key, e);
88 None
89 }
90 }
91 })
92 })
93 };
94
95 let set_state = {
96 let storage = storage.clone();
97 let session_id = session_id.clone();
98 Arc::new(move |key: &str, value: Value| -> BoxFuture<()> {
99 let storage = storage.clone();
100 let session_id = session_id.clone();
101 let key = key.to_string();
102 Box::pin(async move {
103 if let Err(e) = storage.set_session_state(&session_id, &key, value).await {
104 tracing::error!("Failed to set session state for key '{}': {}", key, e);
105 }
106 })
107 })
108 };
109
110 let remove_state = {
111 let storage = storage.clone();
112 let session_id = session_id.clone();
113 Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
114 let storage = storage.clone();
115 let session_id = session_id.clone();
116 let key = key.to_string();
117 Box::pin(async move {
118 match storage.remove_session_state(&session_id, &key).await {
119 Ok(value) => value,
120 Err(e) => {
121 tracing::warn!(
122 "Failed to remove session state for key '{}': {}",
123 key,
124 e
125 );
126 None
127 }
128 }
129 })
130 })
131 };
132
133 let is_initialized = {
134 let storage = storage.clone();
135 let session_id = session_id.clone();
136 Arc::new(move || -> BoxFuture<bool> {
137 let storage = storage.clone();
138 let session_id = session_id.clone();
139 Box::pin(async move {
140 match storage.get_session(&session_id).await {
141 Ok(Some(session_info)) => session_info.is_initialized,
142 Ok(None) => {
143 tracing::warn!("Session {} not found in storage", session_id);
144 false
145 }
146 Err(e) => {
147 tracing::error!("Failed to check session initialization: {}", e);
148 false
149 }
150 }
151 })
152 })
153 };
154
155 let send_notification = {
157 let session_id = session_id.clone();
158 let broadcaster = broadcaster.clone();
159 Arc::new(move |event: SessionEvent| -> BoxFuture<()> {
160 let session_id = session_id.clone();
161 let broadcaster = broadcaster.clone();
162 Box::pin(async move {
163 debug!(
164 "📨 SessionContext.send_notification() called for session {}: {:?}",
165 session_id, event
166 );
167
168 if let Some(broadcaster_any) = &broadcaster {
170 debug!(
171 "✅ NotificationBroadcaster available for session: {}",
172 session_id
173 );
174
175 match event {
177 SessionEvent::Notification(json_value) => {
178 debug!(
179 "🔧 Attempting to send notification via StreamManagerNotificationBroadcaster"
180 );
181 debug!("📦 Notification JSON: {}", json_value);
182
183 match parse_and_send_notification_with_broadcaster(
185 &session_id,
186 &json_value,
187 broadcaster_any,
188 )
189 .await
190 {
191 Ok(_) => debug!(
192 "✅ Bridge working: Successfully processed notification for session {}",
193 session_id
194 ),
195 Err(e) => error!(
196 "❌ Bridge error: Failed to process notification for session {}: {}",
197 session_id, e
198 ),
199 }
200 }
201 _ => {
202 debug!("⚠️ Non-notification event, ignoring: {:?}", event);
203 }
204 }
205 } else {
206 debug!("⚠️ No broadcaster available for session {}", session_id);
207 }
208 })
209 })
210 };
211
212 SessionContext {
213 session_id,
214 get_state,
215 set_state,
216 remove_state,
217 is_initialized,
218 send_notification,
219 broadcaster,
220 }
221 }
222
223 pub fn has_broadcaster(&self) -> bool {
225 self.broadcaster.is_some()
226 }
227
228 pub fn get_raw_broadcaster(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
230 self.broadcaster.clone()
231 }
232
233 #[cfg(feature = "test-utils")]
235 pub fn from_json_rpc_with_broadcaster_for_tests(
236 json_rpc_ctx: turul_mcp_json_rpc_server::SessionContext,
237 storage: Arc<dyn SessionStorage<Error = SessionStorageError>>,
238 ) -> Self {
239 Self::from_json_rpc_with_broadcaster(json_rpc_ctx, storage)
240 }
241
242 pub async fn get_typed_state<T>(&self, key: &str) -> Option<T>
244 where
245 T: serde::de::DeserializeOwned,
246 {
247 (self.get_state)(key)
248 .await
249 .and_then(|v| serde_json::from_value(v).ok())
250 }
251
252 pub async fn set_typed_state<T>(&self, key: &str, value: T) -> Result<(), String>
254 where
255 T: serde::Serialize,
256 {
257 match serde_json::to_value(value) {
258 Ok(json_value) => {
259 (self.set_state)(key, json_value).await;
260 Ok(())
261 }
262 Err(e) => Err(format!("Failed to serialize value: {}", e)),
263 }
264 }
265
266 #[cfg(test)]
268 pub fn new_test() -> Self {
269 use std::collections::HashMap;
270 use std::sync::Arc;
271 use tokio::sync::RwLock;
272
273 let state = Arc::new(RwLock::new(HashMap::<String, Value>::new()));
274
275 let get_state = {
276 let state = state.clone();
277 Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
278 let state = state.clone();
279 let key = key.to_string();
280 Box::pin(async move { state.read().await.get(&key).cloned() })
281 })
282 };
283
284 let set_state = {
285 let state = state.clone();
286 Arc::new(move |key: &str, value: Value| -> BoxFuture<()> {
287 let state = state.clone();
288 let key = key.to_string();
289 Box::pin(async move {
290 state.write().await.insert(key, value);
291 })
292 })
293 };
294
295 let remove_state = {
296 let state = state.clone();
297 Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
298 let state = state.clone();
299 let key = key.to_string();
300 Box::pin(async move { state.write().await.remove(&key) })
301 })
302 };
303
304 let is_initialized = Arc::new(|| -> BoxFuture<bool> { Box::pin(async { true }) });
305
306 let send_notification = Arc::new(|_event: SessionEvent| -> BoxFuture<()> {
307 Box::pin(async {})
308 });
309
310 SessionContext {
311 session_id: Uuid::now_v7().to_string(),
312 get_state,
313 set_state,
314 remove_state,
315 is_initialized,
316 send_notification,
317 broadcaster: None,
318 }
319 }
320
321 pub async fn notify(&self, event: SessionEvent) {
323 debug!(
324 "📨 SessionContext.notify() called for session {}: {:?}",
325 self.session_id, event
326 );
327 (self.send_notification)(event).await;
328 debug!("🚀 SessionContext.notify() send_notification closure completed");
329 }
330
331 pub async fn notify_progress(&self, progress_token: impl Into<String>, progress: u64) {
333 if self.has_broadcaster() {
334 debug!(
335 "🔔 notify_progress using NotificationBroadcaster for session: {}",
336 self.session_id
337 );
338 } else {
340 debug!(
341 "🔔 notify_progress using OLD SessionManager for session: {}",
342 self.session_id
343 );
344 }
345 let mut other = std::collections::HashMap::new();
346 other.insert(
347 "progressToken".to_string(),
348 serde_json::json!(progress_token.into()),
349 );
350 other.insert("progress".to_string(), serde_json::json!(progress));
351
352 let params = turul_mcp_protocol::RequestParams { meta: None, other };
353 let notification =
354 turul_mcp_protocol::JsonRpcNotification::new("notifications/progress".to_string())
355 .with_params(params);
356 self.notify(SessionEvent::Notification(
357 serde_json::to_value(notification).unwrap(),
358 ))
359 .await;
360 }
361
362 pub async fn notify_progress_with_total(
364 &self,
365 progress_token: impl Into<String>,
366 progress: u64,
367 total: u64,
368 ) {
369 let mut other = std::collections::HashMap::new();
370 other.insert(
371 "progressToken".to_string(),
372 serde_json::json!(progress_token.into()),
373 );
374 other.insert("progress".to_string(), serde_json::json!(progress));
375 other.insert("total".to_string(), serde_json::json!(total));
376
377 let params = turul_mcp_protocol::RequestParams { meta: None, other };
378 let notification =
379 turul_mcp_protocol::JsonRpcNotification::new("notifications/progress".to_string())
380 .with_params(params);
381 self.notify(SessionEvent::Notification(
382 serde_json::to_value(notification).unwrap(),
383 ))
384 .await;
385 }
386
387 pub async fn notify_log(
389 &self,
390 level: turul_mcp_protocol::logging::LoggingLevel,
391 data: serde_json::Value,
392 logger: Option<String>,
393 meta: Option<std::collections::HashMap<String, serde_json::Value>>,
394 ) {
395 let message_level = level;
397
398 if !self.should_log(message_level).await {
400 let threshold = self.get_logging_level().await;
401 debug!(
402 "🔕 Filtering out {:?} level message for session {} (threshold: {:?})",
403 message_level, self.session_id, threshold
404 );
405 return;
406 }
407
408 let threshold = self.get_logging_level().await;
409 debug!(
410 "📢 Sending {:?} level message to session {} (threshold: {:?})",
411 message_level, self.session_id, threshold
412 );
413
414 use turul_mcp_protocol::notifications::LoggingMessageNotification;
416 let mut notification = LoggingMessageNotification::new(message_level, data);
417
418 if let Some(logger) = logger {
420 notification = notification.with_logger(logger);
421 }
422
423 if let Some(meta) = meta {
425 notification = notification.with_meta(meta);
426 }
427
428 if self.has_broadcaster() {
429 debug!(
430 "🔔 notify_log using NotificationBroadcaster for session: {}",
431 self.session_id
432 );
433 self.notify(SessionEvent::Notification(
435 serde_json::to_value(notification).unwrap(),
436 ))
437 .await;
438 return;
439 } else {
440 debug!(
441 "🔔 notify_log using OLD SessionManager for session: {}",
442 self.session_id
443 );
444 }
445
446 self.notify(SessionEvent::Notification(
448 serde_json::to_value(notification).unwrap(),
449 ))
450 .await;
451 }
452
453 pub async fn notify_resources_changed(&self) {
455 let notification = turul_mcp_protocol::JsonRpcNotification::new(
456 "notifications/resources/listChanged".to_string(),
457 );
458 self.notify(SessionEvent::Notification(
459 serde_json::to_value(notification).unwrap(),
460 ))
461 .await;
462 }
463
464 pub async fn notify_resource_updated(&self, uri: impl Into<String>) {
466 let mut other = std::collections::HashMap::new();
467 other.insert("uri".to_string(), serde_json::json!(uri.into()));
468
469 let params = turul_mcp_protocol::RequestParams { meta: None, other };
470 let notification = turul_mcp_protocol::JsonRpcNotification::new(
471 "notifications/resources/updated".to_string(),
472 )
473 .with_params(params);
474 self.notify(SessionEvent::Notification(
475 serde_json::to_value(notification).unwrap(),
476 ))
477 .await;
478 }
479
480 pub async fn notify_tools_changed(&self) {
482 let notification = turul_mcp_protocol::JsonRpcNotification::new(
483 "notifications/tools/listChanged".to_string(),
484 );
485 self.notify(SessionEvent::Notification(
486 serde_json::to_value(notification).unwrap(),
487 ))
488 .await;
489 }
490
491 pub async fn get_logging_level(&self) -> turul_mcp_protocol::logging::LoggingLevel {
497 use turul_mcp_protocol::logging::LoggingLevel;
498
499 if let Some(level_value) = (self.get_state)("mcp:logging:level").await {
501 if let Some(level_str) = level_value.as_str() {
502 match level_str {
503 "debug" => LoggingLevel::Debug,
504 "info" => LoggingLevel::Info,
505 "notice" => LoggingLevel::Notice,
506 "warning" => LoggingLevel::Warning,
507 "error" => LoggingLevel::Error,
508 "critical" => LoggingLevel::Critical,
509 "alert" => LoggingLevel::Alert,
510 "emergency" => LoggingLevel::Emergency,
511 _ => LoggingLevel::Info, }
513 } else {
514 LoggingLevel::Info }
516 } else {
517 LoggingLevel::Info }
519 }
520
521 pub async fn set_logging_level(&self, level: turul_mcp_protocol::logging::LoggingLevel) {
523 use turul_mcp_protocol::logging::LoggingLevel;
524
525 let level_str = match level {
526 LoggingLevel::Debug => "debug",
527 LoggingLevel::Info => "info",
528 LoggingLevel::Notice => "notice",
529 LoggingLevel::Warning => "warning",
530 LoggingLevel::Error => "error",
531 LoggingLevel::Critical => "critical",
532 LoggingLevel::Alert => "alert",
533 LoggingLevel::Emergency => "emergency",
534 };
535
536 (self.set_state)("mcp:logging:level", serde_json::json!(level_str)).await;
537 debug!(
538 "🎯 Set logging level for session {}: {:?}",
539 self.session_id, level
540 );
541 }
542
543 pub async fn should_log(
545 &self,
546 message_level: turul_mcp_protocol::logging::LoggingLevel,
547 ) -> bool {
548 let session_threshold = self.get_logging_level().await;
549 message_level.should_log(session_threshold)
550 }
551
552 pub fn should_log_sync(
554 &self,
555 message_level: turul_mcp_protocol::logging::LoggingLevel,
556 ) -> bool {
557 let session_level = futures::executor::block_on(self.get_logging_level());
559 message_level.should_log(session_level)
560 }
561}
562
563#[async_trait]
573impl SessionView for SessionContext {
574 fn session_id(&self) -> &str {
575 &self.session_id
576 }
577
578 async fn get_state(&self, key: &str) -> Result<Option<Value>, String> {
579 Ok((self.get_state)(key).await)
580 }
581
582 async fn set_state(&self, key: &str, value: Value) -> Result<(), String> {
583 (self.set_state)(key, value).await;
584 Ok(())
585 }
586
587 async fn get_metadata(&self, key: &str) -> Result<Option<Value>, String> {
588 let metadata_key = format!("__meta__:{}", key);
590 Ok((self.get_state)(&metadata_key).await)
591 }
592
593 async fn set_metadata(&self, key: &str, value: Value) -> Result<(), String> {
594 let metadata_key = format!("__meta__:{}", key);
596 (self.set_state)(&metadata_key, value).await;
597 Ok(())
598 }
599}
600
601impl turul_mcp_builders::logging::LoggingTarget for SessionContext {
607 fn should_log(&self, level: turul_mcp_protocol::logging::LoggingLevel) -> bool {
608 self.should_log_sync(level)
609 }
610
611 fn notify_log(
612 &self,
613 level: turul_mcp_protocol::logging::LoggingLevel,
614 data: serde_json::Value,
615 logger: Option<String>,
616 meta: Option<std::collections::HashMap<String, serde_json::Value>>,
617 ) {
618 let session_ctx = self.clone();
620 tokio::spawn(async move {
621 session_ctx.notify_log(level, data, logger, meta).await;
622 });
623 }
624}
625
626async fn parse_and_send_notification_with_broadcaster(
628 session_id: &str,
629 json_value: &Value,
630 broadcaster_any: &Arc<dyn std::any::Any + Send + Sync>,
631) -> Result<(), String> {
632 debug!(
633 "🔍 Parsing notification JSON for session {}: {:?}",
634 session_id, json_value
635 );
636
637 use turul_http_mcp_server::notification_bridge::SharedNotificationBroadcaster;
639 use turul_mcp_protocol::notifications::{LoggingMessageNotification, ProgressNotification};
640 debug!(
642 "🔍 Attempting downcast for session {}, broadcaster type: {:?}",
643 session_id,
644 std::any::type_name::<SharedNotificationBroadcaster>()
645 );
646 if let Some(broadcaster) = broadcaster_any.downcast_ref::<SharedNotificationBroadcaster>() {
647 debug!(
648 "✅ Successfully downcast broadcaster for session {}",
649 session_id
650 );
651
652 if let Some(method) = json_value.get("method").and_then(|v| v.as_str()) {
654 match method {
655 "notifications/message" => {
656 debug!(
657 "📝 Message notification detected, deserializing directly to LoggingMessageNotification"
658 );
659
660 match serde_json::from_value::<LoggingMessageNotification>(json_value.clone()) {
662 Ok(notification) => {
663 debug!(
664 "✅ Successfully deserialized LoggingMessageNotification: level={:?}, logger={:?}",
665 notification.params.level, notification.params.logger
666 );
667
668 debug!(
669 "🔧 About to call broadcaster.send_message_notification() for session {}",
670 session_id
671 );
672 match broadcaster
674 .send_message_notification(session_id, notification)
675 .await
676 {
677 Ok(()) => {
678 debug!(
679 "🎉 SUCCESS: LoggingMessageNotification sent to StreamManager for session {}",
680 session_id
681 );
682 debug!(
683 "🚀 Streamable HTTP Transport Bridge: Complete end-to-end delivery confirmed!"
684 );
685 return Ok(());
686 }
687 Err(e) => {
688 error!(
689 "❌ Failed to send LoggingMessageNotification to StreamManager: {}",
690 e
691 );
692 return Err(format!(
693 "Failed to send LoggingMessageNotification: {}",
694 e
695 ));
696 }
697 }
698 }
699 Err(e) => {
700 error!("❌ Failed to deserialize LoggingMessageNotification: {}", e);
701 return Err(format!(
702 "Failed to deserialize LoggingMessageNotification: {}",
703 e
704 ));
705 }
706 }
707 }
708 "notifications/progress" => {
709 if let Some(params) = json_value.get("params")
710 && let Some(token) = params.get("progressToken").and_then(|v| v.as_str())
711 {
712 debug!("📊 Progress notification detected: token={}", token);
713
714 let progress = params.get("progress").and_then(|v| v.as_u64()).unwrap_or(0);
716
717 let notification = ProgressNotification {
719 method: "notifications/progress".to_string(),
720 params: turul_mcp_protocol::notifications::ProgressNotificationParams {
721 progress_token: token.to_string(),
722 progress,
723 total: params.get("total").and_then(|v| v.as_u64()),
724 message: params
725 .get("message")
726 .and_then(|v| v.as_str())
727 .map(|s| s.to_string()),
728 meta: None,
729 },
730 };
731
732 debug!(
733 "🔧 About to call broadcaster.send_progress_notification() for session {}",
734 session_id
735 );
736 match broadcaster
738 .send_progress_notification(session_id, notification)
739 .await
740 {
741 Ok(()) => {
742 debug!(
743 "🎉 SUCCESS: ProgressNotification sent to StreamManager for session {}",
744 session_id
745 );
746 debug!(
747 "🚀 Streamable HTTP Transport Bridge: Complete end-to-end delivery confirmed!"
748 );
749 return Ok(());
750 }
751 Err(e) => {
752 error!(
753 "❌ Failed to send ProgressNotification to StreamManager: {}",
754 e
755 );
756 return Err(format!("Failed to send ProgressNotification: {}", e));
757 }
758 }
759 }
760 }
761 _ => {
762 debug!(
763 "🔧 Other notification method: {} - sending as generic JsonRpcNotification",
764 method
765 );
766
767 let params_map: std::collections::HashMap<String, serde_json::Value> =
769 json_value
770 .get("params")
771 .and_then(|p| p.as_object())
772 .unwrap_or(&serde_json::Map::new())
773 .iter()
774 .map(|(k, v)| (k.clone(), v.clone()))
775 .collect();
776 let json_rpc_notification =
777 turul_mcp_json_rpc_server::JsonRpcNotification::new_with_object_params(
778 method.to_string(),
779 params_map,
780 );
781
782 match broadcaster
783 .send_notification(session_id, json_rpc_notification)
784 .await
785 {
786 Ok(()) => {
787 debug!(
788 "🎉 SUCCESS: Generic notification sent to StreamManager for session {}",
789 session_id
790 );
791 return Ok(());
792 }
793 Err(e) => {
794 error!(
795 "❌ Failed to send generic notification to StreamManager: {}",
796 e
797 );
798 return Err(format!("Failed to send generic notification: {}", e));
799 }
800 }
801 }
802 }
803 }
804 } else {
805 error!(
806 "❌ Failed to downcast broadcaster for session {}",
807 session_id
808 );
809 return Err("Failed to downcast broadcaster to SharedNotificationBroadcaster".to_string());
810 }
811
812 debug!(
813 "❓ Could not determine notification type for session {}",
814 session_id
815 );
816 Ok(())
817}
818
/// Event that can be delivered to a session.
///
/// `SessionContext`'s storage-backed `send_notification` callback forwards only
/// `Notification`; the other variants are logged and ignored on that path.
#[derive(Debug, Clone)]
pub enum SessionEvent {
    /// A JSON-RPC notification, already serialized to JSON.
    Notification(Value),
    /// Keep-alive signal (consumer-defined semantics — not handled in this file).
    KeepAlive,
    /// Disconnect signal (consumer-defined semantics — not handled in this file).
    Disconnect,
    /// Application-defined event: a type tag plus a JSON payload.
    Custom { event_type: String, data: Value },
}
831
/// In-memory (cached) representation of an MCP session held by `SessionManager`.
#[derive(Debug)]
pub struct McpSession {
    /// Session identifier.
    pub id: String,
    /// When the session was created (reconstructed approximately when loaded from storage).
    pub created: Instant,
    /// Last activity time; compared against the session timeout by `is_expired`.
    pub last_accessed: Instant,
    /// Negotiated MCP protocol version (`McpVersion::CURRENT` until initialized with a version).
    pub mcp_version: McpVersion,
    /// Client capabilities, set during initialization.
    pub client_capabilities: Option<ClientCapabilities>,
    /// Capabilities the server advertises to this session.
    pub server_capabilities: ServerCapabilities,
    /// Client implementation info, set during initialization.
    pub client_info: Option<Implementation>,
    /// Per-session key/value state.
    pub state: HashMap<String, Value>,
    /// Broadcast sender for this session's events (see `send_event` / `subscribe_events`).
    pub event_sender: broadcast::Sender<SessionEvent>,
    /// Whether the MCP initialize handshake has completed.
    pub initialized: bool,
}
856
857impl McpSession {
858 pub fn new(server_capabilities: ServerCapabilities) -> Self {
860 let session_id = Uuid::now_v7().to_string();
861 let (event_sender, _) = broadcast::channel(128);
862
863 Self {
864 id: session_id,
865 created: Instant::now(),
866 last_accessed: Instant::now(),
867 mcp_version: McpVersion::CURRENT,
868 client_capabilities: None,
869 server_capabilities,
870 client_info: None,
871 state: HashMap::new(),
872 event_sender,
873 initialized: false,
874 }
875 }
876
877 pub fn touch(&mut self) {
879 self.last_accessed = Instant::now();
880 }
881
882 pub fn is_expired(&self, timeout: Duration) -> bool {
884 self.last_accessed.elapsed() > timeout
885 }
886
887 pub fn initialize(
889 &mut self,
890 client_info: Implementation,
891 client_capabilities: ClientCapabilities,
892 ) {
893 self.client_info = Some(client_info);
894 self.client_capabilities = Some(client_capabilities);
895 self.initialized = true;
896 self.touch();
897 }
898
899 pub fn initialize_with_version(
901 &mut self,
902 client_info: Implementation,
903 client_capabilities: ClientCapabilities,
904 mcp_version: McpVersion,
905 ) {
906 self.client_info = Some(client_info);
907 self.client_capabilities = Some(client_capabilities);
908 self.mcp_version = mcp_version;
909 self.initialized = true;
910 self.touch();
911 }
912
913 pub fn get_state(&self, key: &str) -> Option<Value> {
915 self.state.get(key).cloned()
916 }
917
918 pub fn set_state(&mut self, key: &str, value: Value) {
920 self.state.insert(key.to_string(), value);
921 self.touch();
922 }
923
924 pub fn remove_state(&mut self, key: &str) -> Option<Value> {
926 let result = self.state.remove(key);
927 if result.is_some() {
928 self.touch();
929 }
930 result
931 }
932
933 pub fn send_event(&self, event: SessionEvent) -> Result<(), String> {
935 self.event_sender
936 .send(event)
937 .map_err(|e| format!("Failed to send event: {}", e))?;
938 Ok(())
939 }
940
941 pub fn subscribe_events(&self) -> broadcast::Receiver<SessionEvent> {
943 self.event_sender.subscribe()
944 }
945}
946
/// Errors reported by `SessionManager` operations; each variant carries a session id or a
/// description.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No session with the given id is known.
    #[error("Session not found: {0}")]
    NotFound(String),
    /// The session exceeded its inactivity timeout (and was evicted from the cache).
    #[error("Session expired: {0}")]
    Expired(String),
    /// The session has not completed initialization.
    #[error("Session not initialized: {0}")]
    NotInitialized(String),
    /// Session data could not be interpreted.
    #[error("Invalid session data: {0}")]
    InvalidData(String),
    /// The storage backend reported a failure (message text of the backend error).
    #[error("Storage error: {0}")]
    StorageError(String),
}
961
/// Manages MCP sessions on top of a pluggable storage backend, with an in-memory cache of
/// `McpSession`s.
pub struct SessionManager {
    /// Backing session store; most reads and writes go here first.
    storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
    /// In-memory session cache, also used as a fallback when storage reads fail.
    sessions: RwLock<HashMap<String, McpSession>>,
    /// Inactivity timeout after which a session counts as expired.
    session_timeout: Duration,
    /// Interval for periodic cleanup (presumably used by code outside this view — confirm).
    cleanup_interval: Duration,
    /// Server capabilities given to newly created sessions.
    default_capabilities: ServerCapabilities,
    /// Manager-wide `(session_id, event)` broadcast channel (usage not visible here — confirm).
    global_event_sender: broadcast::Sender<(String, SessionEvent)>,
}
977
978impl SessionManager {
979 pub fn new(default_capabilities: ServerCapabilities) -> Self {
981 let storage: Arc<turul_mcp_session_storage::BoxedSessionStorage> =
982 Arc::new(turul_mcp_session_storage::InMemorySessionStorage::new());
983 Self::with_storage_and_timeouts(
984 storage,
985 default_capabilities,
986 Duration::from_secs(30 * 60), Duration::from_secs(60), )
989 }
990
991 pub fn with_timeouts(
993 default_capabilities: ServerCapabilities,
994 session_timeout: Duration,
995 cleanup_interval: Duration,
996 ) -> Self {
997 let storage: Arc<turul_mcp_session_storage::BoxedSessionStorage> =
998 Arc::new(turul_mcp_session_storage::InMemorySessionStorage::new());
999 Self::with_storage_and_timeouts(
1000 storage,
1001 default_capabilities,
1002 session_timeout,
1003 cleanup_interval,
1004 )
1005 }
1006
1007 pub fn with_storage(
1009 storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
1010 default_capabilities: ServerCapabilities,
1011 ) -> Self {
1012 Self::with_storage_and_timeouts(
1013 storage,
1014 default_capabilities,
1015 Duration::from_secs(30 * 60), Duration::from_secs(60), )
1018 }
1019
1020 pub fn with_storage_and_timeouts(
1022 storage: Arc<turul_mcp_session_storage::BoxedSessionStorage>,
1023 default_capabilities: ServerCapabilities,
1024 session_timeout: Duration,
1025 cleanup_interval: Duration,
1026 ) -> Self {
1027 let (global_event_sender, _) = broadcast::channel(1000);
1028
1029 Self {
1030 storage,
1031 sessions: RwLock::new(HashMap::new()),
1032 session_timeout,
1033 cleanup_interval,
1034 default_capabilities,
1035 global_event_sender,
1036 }
1037 }
1038
1039 pub async fn create_session(&self) -> String {
1041 let session = McpSession::new(self.default_capabilities.clone());
1042 let session_id = session.id.clone();
1043
1044 debug!("Creating new session: {}", session_id);
1045
1046 match self
1048 .storage
1049 .create_session_with_id(session_id.clone(), self.default_capabilities.clone())
1050 .await
1051 {
1052 Ok(_) => debug!("Session {} created in storage backend", session_id),
1053 Err(e) => error!("Failed to create session {} in storage: {}", session_id, e),
1054 }
1055
1056 self.sessions
1058 .write()
1059 .await
1060 .insert(session_id.clone(), session);
1061 session_id
1062 }
1063
1064 pub async fn create_session_with_id(&self, session_id: String) -> String {
1066 let mut session = McpSession::new(self.default_capabilities.clone());
1067 session.id = session_id.clone();
1068
1069 debug!("Creating session with provided ID: {}", session_id);
1070
1071 match self
1073 .storage
1074 .create_session_with_id(session_id.clone(), self.default_capabilities.clone())
1075 .await
1076 {
1077 Ok(_) => debug!("Session {} created in storage backend", session_id),
1078 Err(e) => error!("Failed to create session {} in storage: {}", session_id, e),
1079 }
1080
1081 self.sessions
1083 .write()
1084 .await
1085 .insert(session_id.clone(), session);
1086 session_id
1087 }
1088
1089 pub async fn add_session_to_cache(
1092 &self,
1093 session_id: String,
1094 server_capabilities: ServerCapabilities,
1095 ) {
1096 let mut session = McpSession::new(server_capabilities);
1097 session.id = session_id.clone(); debug!("Adding externally created session {} to cache", session_id);
1100 self.sessions.write().await.insert(session_id, session);
1101 }
1102
1103 pub async fn load_session_from_storage(&self, session_id: &str) -> Result<bool, SessionError> {
1106 match self.storage.get_session(session_id).await {
1107 Ok(Some(session_info)) => {
1108 debug!("Loading session {} from storage", session_id);
1109
1110 let server_capabilities =
1112 session_info.server_capabilities.clone().unwrap_or_else(|| {
1113 warn!(
1114 "Session {} in storage has no server capabilities, using defaults",
1115 session_id
1116 );
1117 self.default_capabilities.clone()
1118 });
1119
1120 let mut session = McpSession::new(server_capabilities);
1121 session.id = session_id.to_string();
1122 session.initialized = session_info.is_initialized;
1123 session.client_capabilities = session_info.client_capabilities.clone();
1124 session.state = session_info.state.clone();
1125
1126 let now = std::time::SystemTime::now()
1129 .duration_since(std::time::UNIX_EPOCH)
1130 .unwrap_or_default()
1131 .as_millis() as u64;
1132
1133 let created_elapsed = if now > session_info.created_at {
1134 Duration::from_millis(now - session_info.created_at)
1135 } else {
1136 Duration::from_secs(0)
1137 };
1138
1139 let last_activity_elapsed = if now > session_info.last_activity {
1140 Duration::from_millis(now - session_info.last_activity)
1141 } else {
1142 Duration::from_secs(0)
1143 };
1144
1145 session.created = Instant::now() - created_elapsed;
1147 session.last_accessed = Instant::now() - last_activity_elapsed;
1148
1149 self.sessions
1151 .write()
1152 .await
1153 .insert(session_id.to_string(), session);
1154
1155 debug!(
1156 "Session {} loaded from storage: initialized={}, has_capabilities={}",
1157 session_id,
1158 session_info.is_initialized,
1159 session_info.server_capabilities.is_some()
1160 );
1161
1162 Ok(true)
1163 }
1164 Ok(None) => {
1165 debug!("Session {} not found in storage", session_id);
1166 Ok(false)
1167 }
1168 Err(e) => {
1169 error!("Failed to get session {} from storage: {}", session_id, e);
1170 Err(SessionError::StorageError(e.to_string()))
1171 }
1172 }
1173 }
1174
1175 pub async fn touch_session(&self, session_id: &str) -> Result<(), SessionError> {
1177 let mut sessions = self.sessions.write().await;
1178 if let Some(session) = sessions.get_mut(session_id) {
1179 if session.is_expired(self.session_timeout) {
1180 sessions.remove(session_id);
1181 return Err(SessionError::Expired(session_id.to_string()));
1182 }
1183 session.touch();
1184 Ok(())
1185 } else {
1186 Err(SessionError::NotFound(session_id.to_string()))
1187 }
1188 }
1189
1190 pub async fn initialize_session(
1192 &self,
1193 session_id: &str,
1194 client_info: Implementation,
1195 client_capabilities: ClientCapabilities,
1196 ) -> Result<(), SessionError> {
1197 if let Ok(Some(mut session_info)) = self.storage.get_session(session_id).await {
1199 session_info.client_capabilities = Some(client_capabilities.clone());
1200 session_info.is_initialized = true;
1201 session_info.touch();
1202
1203 if let Err(e) = self.storage.update_session(session_info).await {
1204 error!("Failed to update session in storage: {}", e);
1205 }
1206 }
1207
1208 let mut sessions = self.sessions.write().await;
1210 if let Some(session) = sessions.get_mut(session_id) {
1211 session.initialize(client_info, client_capabilities);
1212 debug!("Session {} initialized", session_id);
1213 Ok(())
1214 } else {
1215 Err(SessionError::NotFound(session_id.to_string()))
1216 }
1217 }
1218
1219 pub async fn initialize_session_with_version(
1221 &self,
1222 session_id: &str,
1223 client_info: Implementation,
1224 client_capabilities: ClientCapabilities,
1225 mcp_version: McpVersion,
1226 ) -> Result<(), SessionError> {
1227 if let Ok(Some(mut session_info)) = self.storage.get_session(session_id).await {
1229 session_info.client_capabilities = Some(client_capabilities.clone());
1230 session_info.is_initialized = true;
1231 session_info.touch();
1232 if let Err(e) = self.storage.update_session(session_info).await {
1235 error!(
1236 "❌ CRITICAL: Failed to update session {} in storage: {}",
1237 session_id, e
1238 );
1239 return Err(SessionError::StorageError(format!(
1240 "Failed to persist session initialization: {}",
1241 e
1242 )));
1243 }
1244 debug!(
1245 "✅ Session {} storage updated with is_initialized=true",
1246 session_id
1247 );
1248 } else {
1249 error!(
1250 "❌ Session {} not found in storage during initialization",
1251 session_id
1252 );
1253 return Err(SessionError::NotFound(session_id.to_string()));
1254 }
1255
1256 let mut sessions = self.sessions.write().await;
1258 if let Some(session) = sessions.get_mut(session_id) {
1259 session.initialize_with_version(client_info, client_capabilities, mcp_version);
1260 debug!(
1261 "✅ Session {} cache updated with protocol version {}",
1262 session_id, mcp_version
1263 );
1264 Ok(())
1265 } else {
1266 warn!(
1267 "⚠️ Session {} not found in cache but exists in storage - creating cache entry",
1268 session_id
1269 );
1270 Ok(())
1273 }
1274 }
1275
1276 pub async fn session_exists(&self, session_id: &str) -> bool {
1278 match self.storage.get_session(session_id).await {
1280 Ok(Some(session_info)) => {
1281 let timeout_minutes = self.session_timeout.as_secs() / 60;
1283 !session_info.is_expired(timeout_minutes)
1284 }
1285 Ok(None) => false,
1286 Err(e) => {
1287 debug!("Storage backend error for session_exists: {}", e);
1288 let sessions = self.sessions.read().await;
1290 sessions
1291 .get(session_id)
1292 .map(|s| !s.is_expired(self.session_timeout))
1293 .unwrap_or(false)
1294 }
1295 }
1296 }
1297
1298 pub async fn get_session_state(&self, session_id: &str, key: &str) -> Option<Value> {
1300 match self.storage.get_session_state(session_id, key).await {
1302 Ok(value) => value,
1303 Err(e) => {
1304 debug!("Storage backend error for get_session_state: {}", e);
1305 let sessions = self.sessions.read().await;
1307 sessions.get(session_id)?.get_state(key)
1308 }
1309 }
1310 }
1311
1312 pub async fn set_session_state(&self, session_id: &str, key: &str, value: Value) {
1314 if let Err(e) = self
1316 .storage
1317 .set_session_state(session_id, key, value.clone())
1318 .await
1319 {
1320 error!("Failed to set session state in storage: {}", e);
1321 }
1322
1323 let mut sessions = self.sessions.write().await;
1325 if let Some(session) = sessions.get_mut(session_id) {
1326 session.set_state(key, value);
1327 }
1328 }
1329
1330 pub async fn remove_session_state(&self, session_id: &str, key: &str) -> Option<Value> {
1332 let storage_result = match self.storage.remove_session_state(session_id, key).await {
1334 Ok(value) => value,
1335 Err(e) => {
1336 error!("Failed to remove session state from storage: {}", e);
1337 None
1338 }
1339 };
1340
1341 let mut sessions = self.sessions.write().await;
1343 let memory_result = sessions.get_mut(session_id)?.remove_state(key);
1344
1345 storage_result.or(memory_result)
1347 }
1348
1349 pub async fn is_session_initialized(&self, session_id: &str) -> bool {
1351 match self.storage.get_session(session_id).await {
1353 Ok(Some(session_info)) => {
1354 debug!(
1355 "✅ Session {} initialization status from storage: {}",
1356 session_id, session_info.is_initialized
1357 );
1358 session_info.is_initialized
1359 }
1360 Ok(None) => {
1361 debug!("⚠️ Session {} not found in storage", session_id);
1362 false
1363 }
1364 Err(e) => {
1365 warn!(
1366 "⚠️ Failed to check session {} in storage: {} - falling back to cache",
1367 session_id, e
1368 );
1369 let sessions = self.sessions.read().await;
1371 sessions
1372 .get(session_id)
1373 .map(|s| s.initialized)
1374 .unwrap_or(false)
1375 }
1376 }
1377 }
1378
1379 pub async fn remove_session(&self, session_id: &str) -> bool {
1381 let storage_removed = match self.storage.delete_session(session_id).await {
1383 Ok(removed) => {
1384 if removed {
1385 debug!("Session {} removed from storage backend", session_id);
1386 }
1387 removed
1388 }
1389 Err(e) => {
1390 error!(
1391 "Failed to remove session {} from storage: {}",
1392 session_id, e
1393 );
1394 false
1395 }
1396 };
1397
1398 let mut sessions = self.sessions.write().await;
1400 let memory_removed = if let Some(session) = sessions.remove(session_id) {
1401 debug!("Session {} removed from memory cache", session_id);
1402 let _ = session.send_event(SessionEvent::Disconnect);
1404 true
1405 } else {
1406 false
1407 };
1408
1409 storage_removed || memory_removed
1411 }
1412
1413 pub async fn cleanup_expired(&self) -> usize {
1415 let timeout_duration = self.session_timeout;
1416 let cutoff = std::time::SystemTime::now() - timeout_duration;
1417
1418 let storage_removed = match self.storage.expire_sessions(cutoff).await {
1420 Ok(expired_ids) => {
1421 let count = expired_ids.len();
1422 if count > 0 {
1423 info!(
1424 "Storage backend cleaned up {} expired sessions: {:?}",
1425 count, expired_ids
1426 );
1427 }
1428 count
1429 }
1430 Err(e) => {
1431 error!("Failed to clean up expired sessions from storage: {}", e);
1432 0
1433 }
1434 };
1435
1436 let cutoff_instant = Instant::now() - timeout_duration;
1438 let mut sessions = self.sessions.write().await;
1439 let initial_count = sessions.len();
1440
1441 sessions.retain(|id, session| {
1442 let keep = session.last_accessed >= cutoff_instant;
1443 if !keep {
1444 info!("Session {} expired and removed from memory cache", id);
1445 let _ = session.send_event(SessionEvent::Disconnect);
1447 }
1448 keep
1449 });
1450
1451 let memory_removed = initial_count - sessions.len();
1452
1453 std::cmp::max(storage_removed, memory_removed)
1455 }
1456
1457 pub async fn send_event_to_session(
1459 &self,
1460 session_id: &str,
1461 event: SessionEvent,
1462 ) -> Result<(), SessionError> {
1463 let sessions = self.sessions.read().await;
1464 if let Some(session) = sessions.get(session_id) {
1465 session
1467 .send_event(event.clone())
1468 .map_err(SessionError::InvalidData)?;
1469
1470 debug!(
1472 "🌐 Forwarding event to global broadcaster: session={}, event={:?}",
1473 session_id, event
1474 );
1475 if let Err(e) = self
1476 .global_event_sender
1477 .send((session_id.to_string(), event))
1478 {
1479 debug!("⚠️ Global event broadcast failed (no listeners): {}", e);
1480 } else {
1481 debug!("✅ Global event broadcast succeeded");
1482 }
1483
1484 Ok(())
1485 } else {
1486 Err(SessionError::NotFound(session_id.to_string()))
1487 }
1488 }
1489
1490 pub async fn broadcast_event(&self, event: SessionEvent) {
1492 let sessions = self.sessions.read().await;
1493 for (session_id, session) in sessions.iter() {
1494 if let Err(e) = session.send_event(event.clone()) {
1495 warn!("Failed to send event to session {}: {}", session_id, e);
1496 }
1497 }
1498 }
1499
1500 pub async fn session_count(&self) -> usize {
1502 match self.storage.session_count().await {
1504 Ok(count) => count,
1505 Err(e) => {
1506 debug!("Storage backend error for session_count: {}", e);
1507 self.sessions.read().await.len()
1509 }
1510 }
1511 }
1512
1513 pub fn create_session_context(self: &Arc<Self>, session_id: &str) -> Option<SessionContext> {
1515 let session_id = session_id.to_string();
1516 let session_manager = Arc::clone(self);
1517
1518 let get_state = {
1520 let session_manager = session_manager.clone();
1521 let session_id = session_id.clone();
1522 Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
1523 let session_manager = session_manager.clone();
1524 let session_id = session_id.clone();
1525 let key = key.to_string();
1526 Box::pin(async move { session_manager.get_session_state(&session_id, &key).await })
1527 })
1528 };
1529
1530 let set_state = {
1531 let session_manager = session_manager.clone();
1532 let session_id = session_id.clone();
1533 Arc::new(move |key: &str, value: Value| -> BoxFuture<()> {
1534 let session_manager = session_manager.clone();
1535 let session_id = session_id.clone();
1536 let key = key.to_string();
1537 Box::pin(async move {
1538 let _ = session_manager
1539 .set_session_state(&session_id, &key, value)
1540 .await;
1541 })
1542 })
1543 };
1544
1545 let remove_state = {
1546 let session_manager = session_manager.clone();
1547 let session_id = session_id.clone();
1548 Arc::new(move |key: &str| -> BoxFuture<Option<Value>> {
1549 let session_manager = session_manager.clone();
1550 let session_id = session_id.clone();
1551 let key = key.to_string();
1552 Box::pin(async move {
1553 session_manager
1554 .remove_session_state(&session_id, &key)
1555 .await
1556 })
1557 })
1558 };
1559
1560 let is_initialized = {
1561 let session_manager = session_manager.clone();
1562 let session_id = session_id.clone();
1563 Arc::new(move || -> BoxFuture<bool> {
1564 let session_manager = session_manager.clone();
1565 let session_id = session_id.clone();
1566 Box::pin(async move { session_manager.is_session_initialized(&session_id).await })
1567 })
1568 };
1569
1570 let send_notification = {
1571 let session_manager = session_manager.clone();
1572 let session_id = session_id.clone();
1573 Arc::new(move |event: SessionEvent| -> BoxFuture<()> {
1574 let session_manager = session_manager.clone();
1575 let session_id = session_id.clone();
1576 Box::pin(async move {
1577 let _ = session_manager
1578 .send_event_to_session(&session_id, event)
1579 .await;
1580 })
1581 })
1582 };
1583
1584 Some(SessionContext {
1585 session_id,
1586 get_state,
1587 set_state,
1588 remove_state,
1589 is_initialized,
1590 send_notification,
1591 broadcaster: None, })
1593 }
1594
1595 pub fn start_cleanup_task(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
1597 let manager = Arc::clone(&self);
1598 tokio::spawn(async move {
1599 let mut interval = tokio::time::interval(manager.cleanup_interval);
1600 loop {
1601 interval.tick().await;
1602 let cleaned = manager.cleanup_expired().await;
1603 if cleaned > 0 {
1604 debug!("Cleaned up {} expired sessions", cleaned);
1605 }
1606 }
1607 })
1608 }
1609
1610 pub async fn get_session_event_receiver(
1612 &self,
1613 session_id: &str,
1614 ) -> Option<broadcast::Receiver<SessionEvent>> {
1615 let sessions = self.sessions.read().await;
1616 Some(sessions.get(session_id)?.subscribe_events())
1617 }
1618
1619 pub fn subscribe_all_session_events(&self) -> broadcast::Receiver<(String, SessionEvent)> {
1622 self.global_event_sender.subscribe()
1623 }
1624
1625 pub fn get_storage(&self) -> Arc<turul_mcp_session_storage::BoxedSessionStorage> {
1628 Arc::clone(&self.storage)
1629 }
1630
1631 pub fn get_default_capabilities(&self) -> ServerCapabilities {
1633 self.default_capabilities.clone()
1634 }
1635
1636 pub async fn session_exists_in_cache(&self, session_id: &str) -> bool {
1638 self.sessions.read().await.contains_key(session_id)
1639 }
1640}
1641
1642#[async_trait]
1644pub trait SessionAware {
1645 async fn handle_with_session(
1647 &self,
1648 params: Option<Value>,
1649 session: Option<SessionContext>,
1650 ) -> Result<Value, String>;
1651}
1652
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a manager with default server capabilities.
    fn default_manager() -> SessionManager {
        SessionManager::new(ServerCapabilities::default())
    }

    #[tokio::test]
    async fn test_session_creation() {
        let manager = default_manager();

        let id = manager.create_session().await;
        assert!(!id.is_empty());
        assert!(manager.session_exists(&id).await);
    }

    #[tokio::test]
    async fn test_session_state() {
        let manager = default_manager();
        let id = manager.create_session().await;

        manager
            .set_session_state(&id, "test_key", json!("test_value"))
            .await;

        let fetched = manager.get_session_state(&id, "test_key").await;
        assert_eq!(fetched, Some(json!("test_value")));

        let removed = manager.remove_session_state(&id, "test_key").await;
        assert_eq!(removed, Some(json!("test_value")));

        let after_removal = manager.get_session_state(&id, "test_key").await;
        assert_eq!(after_removal, None);
    }

    #[tokio::test]
    async fn test_session_context() {
        let manager = Arc::new(default_manager());
        let id = manager.create_session().await;
        let ctx = manager
            .create_session_context(&id)
            .expect("context for a freshly created session");

        (ctx.set_state)("test", json!("value")).await;

        let fetched = (ctx.get_state)("test").await;
        assert_eq!(fetched, Some(json!("value")));

        let removed = (ctx.remove_state)("test").await;
        assert_eq!(removed, Some(json!("value")));

        // Smoke test: the notification helpers complete without panicking.
        ctx.notify_log(
            turul_mcp_protocol::logging::LoggingLevel::Info,
            json!("Test notification"),
            Some("test".to_string()),
            None,
        )
        .await;
        ctx.notify_progress("test-token", 50).await;
    }

    #[tokio::test]
    async fn test_session_expiry() {
        let mut manager = default_manager();
        manager.session_timeout = Duration::from_millis(100);

        let id = manager.create_session().await;
        assert!(manager.session_exists_in_cache(&id).await);

        tokio::time::sleep(Duration::from_millis(150)).await;

        let result = manager.touch_session(&id).await;
        assert!(matches!(result, Err(SessionError::Expired(_))));
    }
}