Skip to main content

turbomcp_protocol/context/
rich.rs

1//! Rich context extension traits for enhanced tool capabilities.
2//!
3//! This module provides extension traits that add advanced capabilities to
4//! `RequestContext`, including:
5//!
6//! - Session state management (`get_state`, `set_state`)
7//! - Client logging (`info`, `debug`, `warning`, `error`)
8//! - Progress reporting (`report_progress`)
9//!
10//! # Memory Management
11//!
12//! Session state is stored in a process-level map keyed by session ID.
13//! **IMPORTANT**: You must ensure cleanup happens when sessions end to prevent
14//! memory leaks. Use one of these approaches:
15//!
16//! 1. **Recommended**: Use [`SessionStateGuard`] which automatically cleans up on drop
17//! 2. **Manual**: Call [`cleanup_session_state`] when a session disconnects
18//!
19//! # Client Logging and Progress
20//!
21//! The logging and progress methods require bidirectional transport support.
22//! They will silently succeed (no-op) if the transport doesn't support
23//! server-to-client notifications.
24//!
25//! # Example
26//!
27//! ```rust,ignore
28//! use turbomcp_protocol::context::{RichContextExt, SessionStateGuard};
29//!
30//! async fn handle_session(session_id: String) {
31//!     // Guard ensures cleanup when it goes out of scope
32//!     let _guard = SessionStateGuard::new(&session_id);
33//!
34//!     let ctx = RequestContext::new().with_session_id(&session_id);
35//!     ctx.set_state("counter", &0i32);
36//!
37//!     // Client logging
38//!     ctx.info("Starting processing...").await;
39//!
40//!     // Progress reporting
41//!     for i in 0..100 {
//!         ctx.report_progress(i as f64, 100.0, Some(&format!("Step {}", i))).await;
43//!     }
44//!
45//!     ctx.info("Processing complete!").await;
46//!
47//! } // Guard dropped here, session state automatically cleaned up
48//! ```
49
50use std::collections::HashMap;
51use std::sync::Arc;
52
53use parking_lot::RwLock;
54use serde::{Serialize, de::DeserializeOwned};
55use serde_json::Value;
56use turbomcp_core::MaybeSend;
57
58use crate::McpError;
59use crate::types::LogLevel;
60
61use super::request::RequestContext;
62
63/// Type alias for session state storage to reduce complexity.
64type SessionStateMap = dashmap::DashMap<String, Arc<RwLock<HashMap<String, Value>>>>;
65
66/// Session state storage (keyed by session_id).
67///
68/// **⚠️  UNBOUNDED — multi-tenant servers MUST use [`SessionStateGuard`].**
69///
70/// This is a process-level singleton with no LRU/TTL bounds. Without
71/// [`SessionStateGuard`] or explicit [`cleanup_session_state`] calls, every
72/// distinct session id becomes a permanent memory entry. A long-running server
73/// that creates many short-lived sessions (e.g., per-request session ids on a
74/// public HTTP transport) will grow the map without bound until OOM. The
75/// `dashmap` crate has no built-in cap, so a custom LRU layer (`moka`,
76/// hand-rolled) is the only mitigation if `SessionStateGuard` cannot be used.
77///
78/// # Multi-Server Considerations
79///
80/// If you run multiple MCP servers in the same process (e.g., in tests or
81/// composite server scenarios), be aware that session IDs may collide.
82/// To avoid this:
83///
84/// 1. Use unique session ID prefixes per server: `"{server_name}:{session_id}"`
85/// 2. Or ensure each server uses globally unique session IDs (e.g., UUIDs)
86///
87/// This global singleton design enables session state to be shared across
88/// handler invocations without threading server references through the
89/// entire call chain.
90static SESSION_STATE: std::sync::LazyLock<SessionStateMap> =
91    std::sync::LazyLock::new(SessionStateMap::new);
92
93/// RAII guard that automatically cleans up session state when dropped.
94///
95/// This is the recommended way to manage session state lifetime. Create a guard
96/// at the start of a session and let it clean up automatically when the session
97/// ends.
98///
99/// # Example
100///
101/// ```rust,ignore
102/// use turbomcp_protocol::context::SessionStateGuard;
103///
104/// async fn handle_connection(session_id: String) {
105///     let _guard = SessionStateGuard::new(&session_id);
106///
107///     // Session state is available for this session_id
108///     // ...
109///
110/// } // State automatically cleaned up here
111/// ```
112#[derive(Debug)]
113pub struct SessionStateGuard {
114    session_id: String,
115}
116
117impl SessionStateGuard {
118    /// Create a new session state guard.
119    ///
120    /// The session's state will be automatically cleaned up when this guard
121    /// is dropped.
122    pub fn new(session_id: impl Into<String>) -> Self {
123        Self {
124            session_id: session_id.into(),
125        }
126    }
127
128    /// Get the session ID this guard is managing.
129    pub fn session_id(&self) -> &str {
130        &self.session_id
131    }
132}
133
134impl Drop for SessionStateGuard {
135    fn drop(&mut self) {
136        cleanup_session_state(&self.session_id);
137    }
138}
139
/// Reasons a session-state operation can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateError {
    /// The context carries no session ID, so there is no state to address.
    NoSessionId,
    /// Converting the value to JSON failed; carries the serializer's message.
    SerializationFailed(String),
    /// Converting stored JSON into the requested type failed; carries the
    /// deserializer's message.
    DeserializationFailed(String),
}

impl std::fmt::Display for StateError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The two payload-carrying variants share a "<what>: <detail>" shape.
        let (what, detail) = match self {
            Self::NoSessionId => return formatter.write_str("no session ID set on context"),
            Self::SerializationFailed(detail) => ("serialization failed", detail),
            Self::DeserializationFailed(detail) => ("deserialization failed", detail),
        };
        write!(formatter, "{what}: {detail}")
    }
}

impl std::error::Error for StateError {}
162
163/// Extension trait providing rich context capabilities.
164///
165/// This trait extends `RequestContext` with session state management,
166/// client logging, and progress reporting.
167pub trait RichContextExt {
168    // ===== State Management =====
169
170    /// Get a value from session state.
171    ///
172    /// Returns `None` if the key doesn't exist or if there's no session.
173    fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T>;
174
175    /// Try to get a value from session state with detailed error information.
176    ///
177    /// Returns `Err` if there's no session ID or deserialization fails.
178    /// Returns `Ok(None)` if the key doesn't exist.
179    fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError>;
180
181    /// Set a value in session state.
182    ///
183    /// Returns `false` if there's no session ID to store state against.
184    fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool;
185
186    /// Try to set a value in session state with detailed error information.
187    fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError>;
188
189    /// Remove a value from session state.
190    fn remove_state(&self, key: &str) -> bool;
191
192    /// Clear all session state.
193    fn clear_state(&self);
194
195    /// Check if a state key exists.
196    fn has_state(&self, key: &str) -> bool;
197
198    // ===== Client Logging =====
199
200    /// Send a debug-level log message to the client.
201    ///
202    /// Returns `Ok(())` if the notification was sent or if bidirectional
203    /// transport is not available (no-op in that case).
204    fn debug(
205        &self,
206        message: impl Into<String> + MaybeSend,
207    ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
208
209    /// Send an info-level log message to the client.
210    ///
211    /// Returns `Ok(())` if the notification was sent or if bidirectional
212    /// transport is not available (no-op in that case).
213    fn info(
214        &self,
215        message: impl Into<String> + MaybeSend,
216    ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
217
218    /// Send a warning-level log message to the client.
219    ///
220    /// Returns `Ok(())` if the notification was sent or if bidirectional
221    /// transport is not available (no-op in that case).
222    fn warning(
223        &self,
224        message: impl Into<String> + MaybeSend,
225    ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
226
227    /// Send an error-level log message to the client.
228    ///
229    /// Returns `Ok(())` if the notification was sent or if bidirectional
230    /// transport is not available (no-op in that case).
231    fn error(
232        &self,
233        message: impl Into<String> + MaybeSend,
234    ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
235
236    /// Send a log message to the client with a specific level.
237    ///
238    /// This is the low-level method that `debug`, `info`, `warning`, and `error`
239    /// delegate to. Use this when you need fine-grained control over the log level.
240    ///
241    /// Returns `Ok(())` if the notification was sent or if bidirectional
242    /// transport is not available (no-op in that case).
243    fn log(
244        &self,
245        level: LogLevel,
246        message: impl Into<String> + MaybeSend,
247        logger: Option<String>,
248    ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
249
250    // ===== Progress Reporting =====
251
252    /// Report progress on a long-running operation.
253    ///
254    /// Per MCP 2025-11-25 (`schema.ts:1551-1561`), progress and total are JSON
255    /// numbers; floats are permitted to express fractional progress.
256    ///
257    /// # Arguments
258    ///
259    /// * `current` - Current progress value
260    /// * `total` - Total value (for percentage: current/total * 100)
261    /// * `message` - Optional status message
262    ///
263    /// # Example
264    ///
265    /// ```rust,ignore
266    /// for i in 0..100 {
267    ///     ctx.report_progress(i as f64, 100.0, Some(&format!("Processing item {}", i))).await?;
268    /// }
269    /// ```
270    fn report_progress(
271        &self,
272        current: f64,
273        total: f64,
274        message: Option<&str>,
275    ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
276
277    /// Report progress with a custom [`ProgressToken`](crate::types::ProgressToken).
278    ///
279    /// Use this when you need to track multiple concurrent operations with
280    /// different progress tokens (per spec, `string | number`).
281    fn report_progress_with_token(
282        &self,
283        token: impl Into<crate::types::ProgressToken> + MaybeSend,
284        current: f64,
285        total: Option<f64>,
286        message: Option<&str>,
287    ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
288}
289
290impl RichContextExt for RequestContext {
291    fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
292        self.try_get_state(key).ok().flatten()
293    }
294
295    fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError> {
296        let session_id = self.session_id.as_ref().ok_or(StateError::NoSessionId)?;
297
298        let Some(state) = SESSION_STATE.get(session_id) else {
299            return Ok(None);
300        };
301
302        let state_read = state.read();
303        let Some(value) = state_read.get(key) else {
304            return Ok(None);
305        };
306
307        serde_json::from_value(value.clone())
308            .map(Some)
309            .map_err(|e| StateError::DeserializationFailed(e.to_string()))
310    }
311
312    fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool {
313        self.try_set_state(key, value).is_ok()
314    }
315
316    fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError> {
317        let session_id = self.session_id.as_ref().ok_or(StateError::NoSessionId)?;
318
319        let json_value = serde_json::to_value(value)
320            .map_err(|e| StateError::SerializationFailed(e.to_string()))?;
321
322        let state = SESSION_STATE
323            .entry(session_id.clone())
324            .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
325
326        state.write().insert(key.to_string(), json_value);
327        Ok(())
328    }
329
330    fn remove_state(&self, key: &str) -> bool {
331        let Some(ref session_id) = self.session_id else {
332            return false;
333        };
334
335        if let Some(state) = SESSION_STATE.get(session_id) {
336            state.write().remove(key);
337            return true;
338        }
339        false
340    }
341
342    fn clear_state(&self) {
343        if let Some(ref session_id) = self.session_id
344            && let Some(state) = SESSION_STATE.get(session_id)
345        {
346            state.write().clear();
347        }
348    }
349
350    fn has_state(&self, key: &str) -> bool {
351        if let Some(ref session_id) = self.session_id
352            && let Some(state) = SESSION_STATE.get(session_id)
353        {
354            return state.read().contains_key(key);
355        }
356        false
357    }
358
359    // ===== Client Logging =====
360
361    async fn debug(&self, message: impl Into<String> + MaybeSend) -> Result<(), McpError> {
362        self.log(LogLevel::Debug, message, None).await
363    }
364
365    async fn info(&self, message: impl Into<String> + MaybeSend) -> Result<(), McpError> {
366        self.log(LogLevel::Info, message, None).await
367    }
368
369    async fn warning(&self, message: impl Into<String> + MaybeSend) -> Result<(), McpError> {
370        self.log(LogLevel::Warning, message, None).await
371    }
372
373    async fn error(&self, message: impl Into<String> + MaybeSend) -> Result<(), McpError> {
374        self.log(LogLevel::Error, message, None).await
375    }
376
377    async fn log(
378        &self,
379        level: LogLevel,
380        message: impl Into<String> + MaybeSend,
381        logger: Option<String>,
382    ) -> Result<(), McpError> {
383        // If no bidirectional session is attached, silently succeed (no-op).
384        if !self.has_session() {
385            return Ok(());
386        }
387
388        let mut params = serde_json::json!({
389            "level": level,
390            "data": message.into(),
391        });
392        if let Some(logger) = logger {
393            params["logger"] = serde_json::Value::String(logger);
394        }
395
396        self.notify_client("notifications/message", params).await
397    }
398
399    // ===== Progress Reporting =====
400
401    async fn report_progress(
402        &self,
403        current: f64,
404        total: f64,
405        message: Option<&str>,
406    ) -> Result<(), McpError> {
407        // Use request_id as the progress token by default
408        self.report_progress_with_token(self.request_id.as_str(), current, Some(total), message)
409            .await
410    }
411
412    async fn report_progress_with_token(
413        &self,
414        token: impl Into<crate::types::ProgressToken> + MaybeSend,
415        current: f64,
416        total: Option<f64>,
417        message: Option<&str>,
418    ) -> Result<(), McpError> {
419        if !self.has_session() {
420            return Ok(());
421        }
422
423        let mut params = serde_json::json!({
424            "progressToken": token.into(),
425            "progress": current,
426        });
427        if let Some(total) = total {
428            params["total"] = serde_json::json!(total);
429        }
430        if let Some(message) = message {
431            params["message"] = serde_json::Value::String(message.to_string());
432        }
433
434        self.notify_client("notifications/progress", params).await
435    }
436}
437
438/// Clean up session state when a session ends.
439///
440/// **Important**: Call this when a session disconnects to free memory.
441/// Alternatively, use [`SessionStateGuard`] for automatic cleanup.
442///
443/// # Example
444///
445/// ```rust,ignore
446/// use turbomcp_protocol::context::cleanup_session_state;
447///
448/// fn on_session_disconnect(session_id: &str) {
449///     cleanup_session_state(session_id);
450/// }
451/// ```
452pub fn cleanup_session_state(session_id: &str) {
453    SESSION_STATE.remove(session_id);
454}
455
456/// Get the number of active sessions with state.
457///
458/// This is useful for monitoring memory usage.
459pub fn active_sessions_count() -> usize {
460    SESSION_STATE.len()
461}
462
/// Drop the stored state of every session at once.
///
/// **Warning**: this affects ALL sessions in the process, not just one.
/// Compiled only for tests.
#[cfg(test)]
pub fn clear_all_session_state() {
    let sessions: &SessionStateMap = &SESSION_STATE;
    sessions.clear();
}
471
#[cfg(test)]
mod tests {
    use super::*;

    // Tests run in parallel against the shared process-wide store, so each test
    // uses its own session IDs and never asserts on global counts alone.

    #[test]
    fn test_get_set_state() {
        let ctx = RequestContext::new().with_session_id("test-session-1");

        // Set state
        assert!(ctx.set_state("counter", &42i32));
        assert!(ctx.set_state("name", &"Alice".to_string()));

        // Get state
        assert_eq!(ctx.get_state::<i32>("counter"), Some(42));
        assert_eq!(ctx.get_state::<String>("name"), Some("Alice".to_string()));
        assert_eq!(ctx.get_state::<i32>("missing"), None);

        // Has state
        assert!(ctx.has_state("counter"));
        assert!(!ctx.has_state("missing"));

        // Remove state
        assert!(ctx.remove_state("counter"));
        assert_eq!(ctx.get_state::<i32>("counter"), None);
        assert!(!ctx.has_state("counter"));

        // Clear state
        ctx.clear_state();
        assert_eq!(ctx.get_state::<String>("name"), None);

        // Cleanup
        cleanup_session_state("test-session-1");
    }

    #[test]
    fn test_state_without_session() {
        let ctx = RequestContext::new();

        // Without session_id, state operations fail
        assert!(!ctx.set_state("key", &"value"));
        assert_eq!(ctx.get_state::<String>("key"), None);
        assert!(!ctx.has_state("key"));
        assert!(!ctx.remove_state("key"));

        // try_* methods return proper errors
        assert_eq!(
            ctx.try_set_state("key", &"value"),
            Err(StateError::NoSessionId)
        );
        assert_eq!(
            ctx.try_get_state::<String>("key"),
            Err(StateError::NoSessionId)
        );
    }

    #[test]
    fn test_session_without_stored_state() {
        // A session that never stored anything has no entry in the store.
        let session_id = "never-stored-session";
        let ctx = RequestContext::new().with_session_id(session_id);

        assert!(!ctx.has_state("key"));
        assert!(!ctx.remove_state("key"));
        assert_eq!(ctx.try_get_state::<i32>("key"), Ok(None));

        // Clearing must not create an entry as a side effect.
        ctx.clear_state();
        assert!(!SESSION_STATE.contains_key(session_id));
    }

    #[test]
    fn test_state_isolation() {
        let ctx1 = RequestContext::new().with_session_id("session-iso-1");
        let ctx2 = RequestContext::new().with_session_id("session-iso-2");

        // Set different values in different sessions
        assert!(ctx1.set_state("value", &1i32));
        assert!(ctx2.set_state("value", &2i32));

        // Each session sees its own value
        assert_eq!(ctx1.get_state::<i32>("value"), Some(1));
        assert_eq!(ctx2.get_state::<i32>("value"), Some(2));

        // Cleanup
        cleanup_session_state("session-iso-1");
        cleanup_session_state("session-iso-2");
    }

    #[test]
    fn test_complex_types() {
        let ctx = RequestContext::new().with_session_id("complex-session-1");

        #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
        struct MyData {
            count: i32,
            items: Vec<String>,
        }

        let data = MyData {
            count: 3,
            items: vec!["a".to_string(), "b".to_string(), "c".to_string()],
        };

        assert!(ctx.set_state("data", &data));
        let retrieved: Option<MyData> = ctx.get_state("data");
        assert_eq!(retrieved, Some(data));

        cleanup_session_state("complex-session-1");
    }

    #[test]
    fn test_session_state_guard() {
        let session_id = "guard-test-session";

        {
            let _guard = SessionStateGuard::new(session_id);
            let ctx = RequestContext::new().with_session_id(session_id);

            assert!(ctx.set_state("key", &"value"));
            assert_eq!(ctx.get_state::<String>("key"), Some("value".to_string()));

            // This session's entry exists while the guard is alive
            assert!(SESSION_STATE.contains_key(session_id));
            assert!(active_sessions_count() > 0);
        }

        // After the guard drops, the whole entry is gone, not just its keys
        assert!(!SESSION_STATE.contains_key(session_id));
        let ctx = RequestContext::new().with_session_id(session_id);
        assert_eq!(ctx.get_state::<String>("key"), None);
    }

    #[test]
    fn test_try_get_state_errors() {
        let ctx = RequestContext::new().with_session_id("error-test-session");
        assert!(ctx.set_state("number", &42i32));

        // Type mismatch returns deserialization error
        let result: Result<Option<String>, StateError> = ctx.try_get_state("number");
        assert!(matches!(result, Err(StateError::DeserializationFailed(_))));
        // ...and the lenient accessor collapses it to `None`
        assert_eq!(ctx.get_state::<String>("number"), None);

        cleanup_session_state("error-test-session");
    }

    #[test]
    fn test_try_set_state_serialization_error() {
        let session_id = "serialization-error-session";
        let ctx = RequestContext::new().with_session_id(session_id);

        // JSON object keys must be strings, so a tuple-keyed map cannot be serialized.
        let mut unserializable = HashMap::new();
        unserializable.insert((1, 2), 3);

        assert!(matches!(
            ctx.try_set_state("bad", &unserializable),
            Err(StateError::SerializationFailed(_))
        ));
        assert!(!ctx.set_state("bad", &unserializable));

        // Serialization fails before the session entry is created.
        assert!(!SESSION_STATE.contains_key(session_id));
    }

    #[test]
    fn test_state_error_display() {
        assert_eq!(
            StateError::NoSessionId.to_string(),
            "no session ID set on context"
        );
        assert!(
            StateError::SerializationFailed("test".into())
                .to_string()
                .contains("serialization failed")
        );
        assert!(
            StateError::DeserializationFailed("test".into())
                .to_string()
                .contains("deserialization failed")
        );
    }

    #[tokio::test]
    async fn test_logging_without_server_to_client() {
        // Without server_to_client configured, logging methods should be no-ops
        let ctx = RequestContext::new().with_session_id("logging-test");

        // These should all succeed (no-op) without server_to_client
        assert!(ctx.debug("debug message").await.is_ok());
        assert!(ctx.info("info message").await.is_ok());
        assert!(ctx.warning("warning message").await.is_ok());
        assert!(ctx.error("error message").await.is_ok());
        assert!(ctx.log(LogLevel::Notice, "notice", None).await.is_ok());
    }

    #[tokio::test]
    async fn test_progress_without_server_to_client() {
        // Without server_to_client configured, progress methods should be no-ops
        let ctx = RequestContext::new().with_session_id("progress-test");

        // These should all succeed (no-op) without server_to_client
        assert!(
            ctx.report_progress(50.0, 100.0, Some("halfway"))
                .await
                .is_ok()
        );
        assert!(ctx.report_progress(100.0, 100.0, None).await.is_ok());
        assert!(
            ctx.report_progress_with_token("custom-token", 25.0, Some(100.0), Some("processing"))
                .await
                .is_ok()
        );
    }
}