Skip to main content

turbomcp_protocol/context/
rich.rs

1//! Rich context extension traits for enhanced tool capabilities.
2//!
3//! This module provides extension traits that add advanced capabilities to
4//! `RequestContext`, including:
5//!
6//! - Session state management (`get_state`, `set_state`)
7//! - Client logging (`info`, `debug`, `warning`, `error`)
8//! - Progress reporting (`report_progress`)
9//!
10//! # Memory Management
11//!
12//! Session state is stored in a process-level map keyed by session ID.
13//! **IMPORTANT**: You must ensure cleanup happens when sessions end to prevent
14//! memory leaks. Use one of these approaches:
15//!
16//! 1. **Recommended**: Use [`SessionStateGuard`] which automatically cleans up on drop
17//! 2. **Manual**: Call [`cleanup_session_state`] when a session disconnects
18//!
19//! # Client Logging and Progress
20//!
//! The logging and progress methods require bidirectional transport support.
//! When the context has no server-to-client channel they are a no-op and
//! return `Ok(())`; when a channel is present, any error from sending the
//! notification is returned to the caller.
24//!
25//! # Example
26//!
//! ```rust,ignore
//! use turbomcp_protocol::McpError;
//! use turbomcp_protocol::context::{RichContextExt, SessionStateGuard};
//!
//! async fn handle_session(session_id: String) -> Result<(), McpError> {
//!     // Guard ensures cleanup when it goes out of scope
//!     let _guard = SessionStateGuard::new(&session_id);
//!
//!     let ctx = RequestContext::new().with_session_id(&session_id);
//!     ctx.set_state("counter", &0i32);
//!
//!     // Client logging (no-op without a server-to-client channel)
//!     ctx.info("Starting processing...").await?;
//!
//!     // Progress reporting
//!     for i in 0..100 {
//!         ctx.report_progress(i, 100, Some(&format!("Step {}", i))).await?;
//!     }
//!
//!     ctx.info("Processing complete!").await?;
//!     Ok(())
//! } // Guard dropped here, session state automatically cleaned up
//! ```
49
50use std::collections::HashMap;
51use std::sync::Arc;
52
53use parking_lot::RwLock;
54use serde::{Serialize, de::DeserializeOwned};
55use serde_json::Value;
56
57use crate::McpError;
58use crate::types::{LogLevel, LoggingNotification, ProgressNotification, ServerNotification};
59
60use super::request::RequestContext;
61
62/// Type alias for session state storage to reduce complexity.
63type SessionStateMap = dashmap::DashMap<String, Arc<RwLock<HashMap<String, Value>>>>;
64
65/// Session state storage (keyed by session_id).
66///
67/// **Warning**: This is a process-level singleton. Entries must be manually
68/// cleaned up via [`cleanup_session_state`] or [`SessionStateGuard`] to
69/// prevent unbounded memory growth.
70///
71/// # Multi-Server Considerations
72///
73/// If you run multiple MCP servers in the same process (e.g., in tests or
74/// composite server scenarios), be aware that session IDs may collide.
75/// To avoid this:
76///
77/// 1. Use unique session ID prefixes per server: `"{server_name}:{session_id}"`
78/// 2. Or ensure each server uses globally unique session IDs (e.g., UUIDs)
79///
80/// This global singleton design enables session state to be shared across
81/// handler invocations without threading server references through the
82/// entire call chain.
83static SESSION_STATE: std::sync::LazyLock<SessionStateMap> =
84    std::sync::LazyLock::new(SessionStateMap::new);
85
86/// RAII guard that automatically cleans up session state when dropped.
87///
88/// This is the recommended way to manage session state lifetime. Create a guard
89/// at the start of a session and let it clean up automatically when the session
90/// ends.
91///
92/// # Example
93///
94/// ```rust,ignore
95/// use turbomcp_protocol::context::SessionStateGuard;
96///
97/// async fn handle_connection(session_id: String) {
98///     let _guard = SessionStateGuard::new(&session_id);
99///
100///     // Session state is available for this session_id
101///     // ...
102///
103/// } // State automatically cleaned up here
104/// ```
105#[derive(Debug)]
106pub struct SessionStateGuard {
107    session_id: String,
108}
109
110impl SessionStateGuard {
111    /// Create a new session state guard.
112    ///
113    /// The session's state will be automatically cleaned up when this guard
114    /// is dropped.
115    pub fn new(session_id: impl Into<String>) -> Self {
116        Self {
117            session_id: session_id.into(),
118        }
119    }
120
121    /// Get the session ID this guard is managing.
122    pub fn session_id(&self) -> &str {
123        &self.session_id
124    }
125}
126
127impl Drop for SessionStateGuard {
128    fn drop(&mut self) {
129        cleanup_session_state(&self.session_id);
130    }
131}
132
/// Errors reported by the fallible (`try_*`) session-state operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateError {
    /// The context carries no session ID, so there is no state to address.
    NoSessionId,
    /// Converting the value to JSON failed; carries the underlying message.
    SerializationFailed(String),
    /// Converting the stored JSON into the requested type failed; carries the
    /// underlying message.
    DeserializationFailed(String),
}

impl std::fmt::Display for StateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The two payload-carrying variants share a "<what>: <detail>" layout.
        let (what, detail) = match self {
            Self::NoSessionId => return f.write_str("no session ID set on context"),
            Self::SerializationFailed(detail) => ("serialization failed", detail),
            Self::DeserializationFailed(detail) => ("deserialization failed", detail),
        };
        write!(f, "{what}: {detail}")
    }
}

impl std::error::Error for StateError {}
155
156/// Extension trait providing rich context capabilities.
157///
158/// This trait extends `RequestContext` with session state management,
159/// client logging, and progress reporting.
160pub trait RichContextExt {
161    // ===== State Management =====
162
163    /// Get a value from session state.
164    ///
165    /// Returns `None` if the key doesn't exist or if there's no session.
166    fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T>;
167
168    /// Try to get a value from session state with detailed error information.
169    ///
170    /// Returns `Err` if there's no session ID or deserialization fails.
171    /// Returns `Ok(None)` if the key doesn't exist.
172    fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError>;
173
174    /// Set a value in session state.
175    ///
176    /// Returns `false` if there's no session ID to store state against.
177    fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool;
178
179    /// Try to set a value in session state with detailed error information.
180    fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError>;
181
182    /// Remove a value from session state.
183    fn remove_state(&self, key: &str) -> bool;
184
185    /// Clear all session state.
186    fn clear_state(&self);
187
188    /// Check if a state key exists.
189    fn has_state(&self, key: &str) -> bool;
190
191    // ===== Client Logging =====
192
193    /// Send a debug-level log message to the client.
194    ///
195    /// Returns `Ok(())` if the notification was sent or if bidirectional
196    /// transport is not available (no-op in that case).
197    fn debug(
198        &self,
199        message: impl Into<String> + Send,
200    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
201
202    /// Send an info-level log message to the client.
203    ///
204    /// Returns `Ok(())` if the notification was sent or if bidirectional
205    /// transport is not available (no-op in that case).
206    fn info(
207        &self,
208        message: impl Into<String> + Send,
209    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
210
211    /// Send a warning-level log message to the client.
212    ///
213    /// Returns `Ok(())` if the notification was sent or if bidirectional
214    /// transport is not available (no-op in that case).
215    fn warning(
216        &self,
217        message: impl Into<String> + Send,
218    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
219
220    /// Send an error-level log message to the client.
221    ///
222    /// Returns `Ok(())` if the notification was sent or if bidirectional
223    /// transport is not available (no-op in that case).
224    fn error(
225        &self,
226        message: impl Into<String> + Send,
227    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
228
229    /// Send a log message to the client with a specific level.
230    ///
231    /// This is the low-level method that `debug`, `info`, `warning`, and `error`
232    /// delegate to. Use this when you need fine-grained control over the log level.
233    ///
234    /// Returns `Ok(())` if the notification was sent or if bidirectional
235    /// transport is not available (no-op in that case).
236    fn log(
237        &self,
238        level: LogLevel,
239        message: impl Into<String> + Send,
240        logger: Option<String>,
241    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
242
243    // ===== Progress Reporting =====
244
245    /// Report progress on a long-running operation.
246    ///
247    /// # Arguments
248    ///
249    /// * `current` - Current progress value
250    /// * `total` - Total value (for percentage: current/total * 100)
251    /// * `message` - Optional status message
252    ///
253    /// # Example
254    ///
255    /// ```rust,ignore
256    /// for i in 0..100 {
257    ///     ctx.report_progress(i, 100, Some(&format!("Processing item {}", i))).await?;
258    /// }
259    /// ```
260    fn report_progress(
261        &self,
262        current: u64,
263        total: u64,
264        message: Option<&str>,
265    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
266
267    /// Report progress with custom progress token.
268    ///
269    /// Use this when you need to track multiple concurrent operations
270    /// with different progress tokens.
271    fn report_progress_with_token(
272        &self,
273        token: impl Into<String> + Send,
274        current: u64,
275        total: Option<u64>,
276        message: Option<&str>,
277    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
278}
279
280impl RichContextExt for RequestContext {
281    fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
282        self.try_get_state(key).ok().flatten()
283    }
284
285    fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError> {
286        let session_id = self.session_id.as_ref().ok_or(StateError::NoSessionId)?;
287
288        let Some(state) = SESSION_STATE.get(session_id) else {
289            return Ok(None);
290        };
291
292        let state_read = state.read();
293        let Some(value) = state_read.get(key) else {
294            return Ok(None);
295        };
296
297        serde_json::from_value(value.clone())
298            .map(Some)
299            .map_err(|e| StateError::DeserializationFailed(e.to_string()))
300    }
301
302    fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool {
303        self.try_set_state(key, value).is_ok()
304    }
305
306    fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError> {
307        let session_id = self.session_id.as_ref().ok_or(StateError::NoSessionId)?;
308
309        let json_value = serde_json::to_value(value)
310            .map_err(|e| StateError::SerializationFailed(e.to_string()))?;
311
312        let state = SESSION_STATE
313            .entry(session_id.clone())
314            .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
315
316        state.write().insert(key.to_string(), json_value);
317        Ok(())
318    }
319
320    fn remove_state(&self, key: &str) -> bool {
321        let Some(ref session_id) = self.session_id else {
322            return false;
323        };
324
325        if let Some(state) = SESSION_STATE.get(session_id) {
326            state.write().remove(key);
327            return true;
328        }
329        false
330    }
331
332    fn clear_state(&self) {
333        if let Some(ref session_id) = self.session_id
334            && let Some(state) = SESSION_STATE.get(session_id)
335        {
336            state.write().clear();
337        }
338    }
339
340    fn has_state(&self, key: &str) -> bool {
341        if let Some(ref session_id) = self.session_id
342            && let Some(state) = SESSION_STATE.get(session_id)
343        {
344            return state.read().contains_key(key);
345        }
346        false
347    }
348
349    // ===== Client Logging =====
350
351    async fn debug(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
352        self.log(LogLevel::Debug, message, None).await
353    }
354
355    async fn info(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
356        self.log(LogLevel::Info, message, None).await
357    }
358
359    async fn warning(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
360        self.log(LogLevel::Warning, message, None).await
361    }
362
363    async fn error(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
364        self.log(LogLevel::Error, message, None).await
365    }
366
367    async fn log(
368        &self,
369        level: LogLevel,
370        message: impl Into<String> + Send,
371        logger: Option<String>,
372    ) -> Result<(), McpError> {
373        // If no server_to_client available, this is a no-op (not an error)
374        let Some(s2c) = self.server_to_client() else {
375            return Ok(());
376        };
377
378        let notification = ServerNotification::Message(LoggingNotification {
379            level,
380            data: serde_json::Value::String(message.into()),
381            logger,
382        });
383
384        s2c.send_notification(notification).await
385    }
386
387    // ===== Progress Reporting =====
388
389    async fn report_progress(
390        &self,
391        current: u64,
392        total: u64,
393        message: Option<&str>,
394    ) -> Result<(), McpError> {
395        // Use request_id as the progress token by default
396        self.report_progress_with_token(&self.request_id, current, Some(total), message)
397            .await
398    }
399
400    async fn report_progress_with_token(
401        &self,
402        token: impl Into<String> + Send,
403        current: u64,
404        total: Option<u64>,
405        message: Option<&str>,
406    ) -> Result<(), McpError> {
407        // If no server_to_client available, this is a no-op (not an error)
408        let Some(s2c) = self.server_to_client() else {
409            return Ok(());
410        };
411
412        let notification = ServerNotification::Progress(ProgressNotification {
413            progress_token: token.into(),
414            progress: current,
415            total,
416            message: message.map(ToString::to_string),
417        });
418
419        s2c.send_notification(notification).await
420    }
421}
422
423/// Clean up session state when a session ends.
424///
425/// **Important**: Call this when a session disconnects to free memory.
426/// Alternatively, use [`SessionStateGuard`] for automatic cleanup.
427///
428/// # Example
429///
430/// ```rust,ignore
431/// use turbomcp_protocol::context::cleanup_session_state;
432///
433/// fn on_session_disconnect(session_id: &str) {
434///     cleanup_session_state(session_id);
435/// }
436/// ```
437pub fn cleanup_session_state(session_id: &str) {
438    SESSION_STATE.remove(session_id);
439}
440
441/// Get the number of active sessions with state.
442///
443/// This is useful for monitoring memory usage.
444pub fn active_sessions_count() -> usize {
445    SESSION_STATE.len()
446}
447
448/// Clear all session state.
449///
450/// **Warning**: This removes state for ALL sessions. Use with caution.
451/// Primarily intended for testing.
452#[cfg(test)]
453pub fn clear_all_session_state() {
454    SESSION_STATE.clear();
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    // SESSION_STATE is shared by every test in the process (tests may run in
    // parallel), so each test uses its own session ID. State is released via a
    // `SessionStateGuard`, which also cleans up if an assertion panics.

    #[test]
    fn test_get_set_state() {
        let session_id = "test-session-1";
        let _guard = SessionStateGuard::new(session_id);
        let ctx = RequestContext::new().with_session_id(session_id);

        // Set state
        assert!(ctx.set_state("counter", &42i32));
        assert!(ctx.set_state("name", &"Alice".to_string()));

        // Get state
        assert_eq!(ctx.get_state::<i32>("counter"), Some(42));
        assert_eq!(ctx.get_state::<String>("name"), Some("Alice".to_string()));
        assert_eq!(ctx.get_state::<i32>("missing"), None);

        // Has state
        assert!(ctx.has_state("counter"));
        assert!(!ctx.has_state("missing"));

        // Remove state
        assert!(ctx.remove_state("counter"));
        assert_eq!(ctx.get_state::<i32>("counter"), None);
        assert!(!ctx.has_state("counter"));

        // Clear state
        ctx.clear_state();
        assert_eq!(ctx.get_state::<String>("name"), None);
    }

    #[test]
    fn test_state_without_session() {
        let ctx = RequestContext::new();

        // Without session_id, state operations fail
        assert!(!ctx.set_state("key", &"value"));
        assert_eq!(ctx.get_state::<String>("key"), None);
        assert!(!ctx.has_state("key"));

        // try_* methods return proper errors
        assert_eq!(
            ctx.try_set_state("key", &"value"),
            Err(StateError::NoSessionId)
        );
        assert_eq!(
            ctx.try_get_state::<String>("key"),
            Err(StateError::NoSessionId)
        );
    }

    #[test]
    fn test_state_mutators_without_session() {
        let ctx = RequestContext::new();

        // Mutators are safe no-ops without a session ID.
        assert!(!ctx.remove_state("key"));
        ctx.clear_state();
        assert!(!ctx.has_state("key"));
    }

    #[test]
    fn test_state_isolation() {
        let _guard1 = SessionStateGuard::new("session-iso-1");
        let _guard2 = SessionStateGuard::new("session-iso-2");
        let ctx1 = RequestContext::new().with_session_id("session-iso-1");
        let ctx2 = RequestContext::new().with_session_id("session-iso-2");

        // Set different values in different sessions
        assert!(ctx1.set_state("value", &1i32));
        assert!(ctx2.set_state("value", &2i32));

        // Each session sees its own value
        assert_eq!(ctx1.get_state::<i32>("value"), Some(1));
        assert_eq!(ctx2.get_state::<i32>("value"), Some(2));
    }

    #[test]
    fn test_complex_types() {
        let session_id = "complex-session-1";
        let _guard = SessionStateGuard::new(session_id);
        let ctx = RequestContext::new().with_session_id(session_id);

        #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
        struct MyData {
            count: i32,
            items: Vec<String>,
        }

        let data = MyData {
            count: 3,
            items: vec!["a".to_string(), "b".to_string(), "c".to_string()],
        };

        assert!(ctx.set_state("data", &data));
        let retrieved: Option<MyData> = ctx.get_state("data");
        assert_eq!(retrieved, Some(data));
    }

    #[test]
    fn test_set_state_overwrites_existing_value() {
        let session_id = "overwrite-session";
        let _guard = SessionStateGuard::new(session_id);
        let ctx = RequestContext::new().with_session_id(session_id);

        // First write creates the session map, second write reuses it.
        assert!(ctx.set_state("counter", &1i32));
        assert!(ctx.set_state("counter", &2i32));
        assert_eq!(ctx.get_state::<i32>("counter"), Some(2));
    }

    #[test]
    fn test_try_get_state_missing_returns_ok_none() {
        let session_id = "try-get-missing-session";
        let _guard = SessionStateGuard::new(session_id);
        let ctx = RequestContext::new().with_session_id(session_id);

        // No state map exists for this session yet.
        assert_eq!(ctx.try_get_state::<i32>("absent"), Ok(None));

        // The map exists but the key doesn't.
        assert!(ctx.set_state("present", &1i32));
        assert_eq!(ctx.try_get_state::<i32>("absent"), Ok(None));
        assert_eq!(ctx.try_get_state::<i32>("present"), Ok(Some(1)));
    }

    #[test]
    fn test_session_state_guard() {
        let session_id = "guard-test-session";

        {
            let _guard = SessionStateGuard::new(session_id);
            let ctx = RequestContext::new().with_session_id(session_id);

            ctx.set_state("key", &"value");
            assert_eq!(ctx.get_state::<String>("key"), Some("value".to_string()));

            // State exists while guard is alive
            assert!(active_sessions_count() > 0);
        }

        // After guard drops, state should be cleaned up
        let ctx = RequestContext::new().with_session_id(session_id);
        assert_eq!(ctx.get_state::<String>("key"), None);
    }

    #[test]
    fn test_try_get_state_errors() {
        let session_id = "error-test-session";
        let _guard = SessionStateGuard::new(session_id);
        let ctx = RequestContext::new().with_session_id(session_id);
        assert!(ctx.set_state("number", &42i32));

        // Type mismatch returns deserialization error
        let result: Result<Option<String>, StateError> = ctx.try_get_state("number");
        assert!(matches!(result, Err(StateError::DeserializationFailed(_))));

        // ...which the lenient getter reports as `None`.
        assert_eq!(ctx.get_state::<String>("number"), None);
    }

    #[test]
    fn test_state_error_display() {
        assert_eq!(
            StateError::NoSessionId.to_string(),
            "no session ID set on context"
        );
        assert!(
            StateError::SerializationFailed("test".into())
                .to_string()
                .contains("serialization failed")
        );
        assert!(
            StateError::DeserializationFailed("test".into())
                .to_string()
                .contains("deserialization failed")
        );
    }

    #[tokio::test]
    async fn test_logging_without_server_to_client() {
        // Without server_to_client configured, logging methods should be no-ops
        let ctx = RequestContext::new().with_session_id("logging-test");

        // These should all succeed (no-op) without server_to_client
        assert!(ctx.debug("debug message").await.is_ok());
        assert!(ctx.info("info message").await.is_ok());
        assert!(ctx.warning("warning message").await.is_ok());
        assert!(ctx.error("error message").await.is_ok());
        assert!(ctx.log(LogLevel::Notice, "notice", None).await.is_ok());
    }

    #[tokio::test]
    async fn test_progress_without_server_to_client() {
        // Without server_to_client configured, progress methods should be no-ops
        let ctx = RequestContext::new().with_session_id("progress-test");

        // These should all succeed (no-op) without server_to_client
        assert!(ctx.report_progress(50, 100, Some("halfway")).await.is_ok());
        assert!(ctx.report_progress(100, 100, None).await.is_ok());
        assert!(
            ctx.report_progress_with_token("custom-token", 25, Some(100), Some("processing"))
                .await
                .is_ok()
        );
    }
}