turbomcp_protocol/context/
rich.rs1use std::collections::HashMap;
51use std::sync::Arc;
52
53use parking_lot::RwLock;
54use serde::{Serialize, de::DeserializeOwned};
55use serde_json::Value;
56
57use crate::McpError;
58use crate::types::{LogLevel, LoggingNotification, ProgressNotification, ServerNotification};
59
60use super::request::RequestContext;
61
62type SessionStateMap = dashmap::DashMap<String, Arc<RwLock<HashMap<String, Value>>>>;
64
65static SESSION_STATE: std::sync::LazyLock<SessionStateMap> =
84 std::sync::LazyLock::new(SessionStateMap::new);
85
86#[derive(Debug)]
106pub struct SessionStateGuard {
107 session_id: String,
108}
109
110impl SessionStateGuard {
111 pub fn new(session_id: impl Into<String>) -> Self {
116 Self {
117 session_id: session_id.into(),
118 }
119 }
120
121 pub fn session_id(&self) -> &str {
123 &self.session_id
124 }
125}
126
127impl Drop for SessionStateGuard {
128 fn drop(&mut self) {
129 cleanup_session_state(&self.session_id);
130 }
131}
132
/// Failure modes of the typed session-state accessors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateError {
    /// The request context carries no session ID, so there is no state to address.
    NoSessionId,
    /// The value could not be converted to JSON; holds the serializer's message.
    SerializationFailed(String),
    /// The stored JSON could not be decoded as the requested type; holds the
    /// deserializer's message.
    DeserializationFailed(String),
}

impl std::fmt::Display for StateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StateError::NoSessionId => f.write_str("no session ID set on context"),
            StateError::SerializationFailed(reason) => {
                write!(f, "serialization failed: {}", reason)
            }
            StateError::DeserializationFailed(reason) => {
                write!(f, "deserialization failed: {}", reason)
            }
        }
    }
}

impl std::error::Error for StateError {}
155
156pub trait RichContextExt {
161 fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T>;
167
168 fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError>;
173
174 fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool;
178
179 fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError>;
181
182 fn remove_state(&self, key: &str) -> bool;
184
185 fn clear_state(&self);
187
188 fn has_state(&self, key: &str) -> bool;
190
191 fn debug(
198 &self,
199 message: impl Into<String> + Send,
200 ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
201
202 fn info(
207 &self,
208 message: impl Into<String> + Send,
209 ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
210
211 fn warning(
216 &self,
217 message: impl Into<String> + Send,
218 ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
219
220 fn error(
225 &self,
226 message: impl Into<String> + Send,
227 ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
228
229 fn log(
237 &self,
238 level: LogLevel,
239 message: impl Into<String> + Send,
240 logger: Option<String>,
241 ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
242
243 fn report_progress(
261 &self,
262 current: u64,
263 total: u64,
264 message: Option<&str>,
265 ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
266
267 fn report_progress_with_token(
272 &self,
273 token: impl Into<String> + Send,
274 current: u64,
275 total: Option<u64>,
276 message: Option<&str>,
277 ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
278}
279
280impl RichContextExt for RequestContext {
281 fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
282 self.try_get_state(key).ok().flatten()
283 }
284
285 fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError> {
286 let session_id = self.session_id.as_ref().ok_or(StateError::NoSessionId)?;
287
288 let Some(state) = SESSION_STATE.get(session_id) else {
289 return Ok(None);
290 };
291
292 let state_read = state.read();
293 let Some(value) = state_read.get(key) else {
294 return Ok(None);
295 };
296
297 serde_json::from_value(value.clone())
298 .map(Some)
299 .map_err(|e| StateError::DeserializationFailed(e.to_string()))
300 }
301
302 fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool {
303 self.try_set_state(key, value).is_ok()
304 }
305
306 fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError> {
307 let session_id = self.session_id.as_ref().ok_or(StateError::NoSessionId)?;
308
309 let json_value = serde_json::to_value(value)
310 .map_err(|e| StateError::SerializationFailed(e.to_string()))?;
311
312 let state = SESSION_STATE
313 .entry(session_id.clone())
314 .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
315
316 state.write().insert(key.to_string(), json_value);
317 Ok(())
318 }
319
320 fn remove_state(&self, key: &str) -> bool {
321 let Some(ref session_id) = self.session_id else {
322 return false;
323 };
324
325 if let Some(state) = SESSION_STATE.get(session_id) {
326 state.write().remove(key);
327 return true;
328 }
329 false
330 }
331
332 fn clear_state(&self) {
333 if let Some(ref session_id) = self.session_id
334 && let Some(state) = SESSION_STATE.get(session_id)
335 {
336 state.write().clear();
337 }
338 }
339
340 fn has_state(&self, key: &str) -> bool {
341 if let Some(ref session_id) = self.session_id
342 && let Some(state) = SESSION_STATE.get(session_id)
343 {
344 return state.read().contains_key(key);
345 }
346 false
347 }
348
349 async fn debug(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
352 self.log(LogLevel::Debug, message, None).await
353 }
354
355 async fn info(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
356 self.log(LogLevel::Info, message, None).await
357 }
358
359 async fn warning(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
360 self.log(LogLevel::Warning, message, None).await
361 }
362
363 async fn error(&self, message: impl Into<String> + Send) -> Result<(), McpError> {
364 self.log(LogLevel::Error, message, None).await
365 }
366
367 async fn log(
368 &self,
369 level: LogLevel,
370 message: impl Into<String> + Send,
371 logger: Option<String>,
372 ) -> Result<(), McpError> {
373 let Some(s2c) = self.server_to_client() else {
375 return Ok(());
376 };
377
378 let notification = ServerNotification::Message(LoggingNotification {
379 level,
380 data: serde_json::Value::String(message.into()),
381 logger,
382 });
383
384 s2c.send_notification(notification).await
385 }
386
387 async fn report_progress(
390 &self,
391 current: u64,
392 total: u64,
393 message: Option<&str>,
394 ) -> Result<(), McpError> {
395 self.report_progress_with_token(&self.request_id, current, Some(total), message)
397 .await
398 }
399
400 async fn report_progress_with_token(
401 &self,
402 token: impl Into<String> + Send,
403 current: u64,
404 total: Option<u64>,
405 message: Option<&str>,
406 ) -> Result<(), McpError> {
407 let Some(s2c) = self.server_to_client() else {
409 return Ok(());
410 };
411
412 let notification = ServerNotification::Progress(ProgressNotification {
413 progress_token: token.into(),
414 progress: current,
415 total,
416 message: message.map(ToString::to_string),
417 });
418
419 s2c.send_notification(notification).await
420 }
421}
422
423pub fn cleanup_session_state(session_id: &str) {
438 SESSION_STATE.remove(session_id);
439}
440
441pub fn active_sessions_count() -> usize {
445 SESSION_STATE.len()
446}
447
/// Test-only helper that drops every session's entry from the state registry.
#[cfg(test)]
pub fn clear_all_session_state() {
    SessionStateMap::clear(&SESSION_STATE);
}
456
#[cfg(test)]
mod tests {
    use super::*;

    // Each stateful test holds a `SessionStateGuard` instead of calling
    // `cleanup_session_state` at the end: the guard also runs when an assertion
    // panics, so a failing test cannot leak entries into the shared registry.

    #[test]
    fn test_get_set_state() {
        let _guard = SessionStateGuard::new("test-session-1");
        let ctx = RequestContext::new().with_session_id("test-session-1");

        assert!(ctx.set_state("counter", &42i32));
        assert!(ctx.set_state("name", &"Alice".to_string()));

        assert_eq!(ctx.get_state::<i32>("counter"), Some(42));
        assert_eq!(ctx.get_state::<String>("name"), Some("Alice".to_string()));
        assert_eq!(ctx.get_state::<i32>("missing"), None);

        assert!(ctx.has_state("counter"));
        assert!(!ctx.has_state("missing"));

        assert!(ctx.remove_state("counter"));
        assert_eq!(ctx.get_state::<i32>("counter"), None);
        assert!(!ctx.has_state("counter"));

        ctx.clear_state();
        assert_eq!(ctx.get_state::<String>("name"), None);
    }

    #[test]
    fn test_state_without_session() {
        let ctx = RequestContext::new();

        assert!(!ctx.set_state("key", &"value"));
        assert_eq!(ctx.get_state::<String>("key"), None);
        assert!(!ctx.has_state("key"));

        assert_eq!(
            ctx.try_set_state("key", &"value"),
            Err(StateError::NoSessionId)
        );
        assert_eq!(
            ctx.try_get_state::<String>("key"),
            Err(StateError::NoSessionId)
        );
    }

    #[test]
    fn test_state_isolation() {
        let _guard1 = SessionStateGuard::new("session-iso-1");
        let _guard2 = SessionStateGuard::new("session-iso-2");
        let ctx1 = RequestContext::new().with_session_id("session-iso-1");
        let ctx2 = RequestContext::new().with_session_id("session-iso-2");

        assert!(ctx1.set_state("value", &1i32));
        assert!(ctx2.set_state("value", &2i32));

        assert_eq!(ctx1.get_state::<i32>("value"), Some(1));
        assert_eq!(ctx2.get_state::<i32>("value"), Some(2));
    }

    #[test]
    fn test_complex_types() {
        let _guard = SessionStateGuard::new("complex-session-1");
        let ctx = RequestContext::new().with_session_id("complex-session-1");

        #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
        struct MyData {
            count: i32,
            items: Vec<String>,
        }

        let data = MyData {
            count: 3,
            items: vec!["a".to_string(), "b".to_string(), "c".to_string()],
        };

        assert!(ctx.set_state("data", &data));
        let retrieved: Option<MyData> = ctx.get_state("data");
        assert_eq!(retrieved, Some(data));
    }

    #[test]
    fn test_session_state_guard() {
        let session_id = "guard-test-session";

        {
            let _guard = SessionStateGuard::new(session_id);
            let ctx = RequestContext::new().with_session_id(session_id);

            assert!(ctx.set_state("key", &"value"));
            assert_eq!(ctx.get_state::<String>("key"), Some("value".to_string()));

            assert!(active_sessions_count() > 0);
        }

        // The guard has been dropped, so the session's state must be gone.
        let ctx = RequestContext::new().with_session_id(session_id);
        assert_eq!(ctx.get_state::<String>("key"), None);
    }

    #[test]
    fn test_try_get_state_errors() {
        let _guard = SessionStateGuard::new("error-test-session");
        let ctx = RequestContext::new().with_session_id("error-test-session");
        assert!(ctx.set_state("number", &42i32));

        // A stored integer cannot be decoded as a `String`.
        let result: Result<Option<String>, StateError> = ctx.try_get_state("number");
        assert!(matches!(result, Err(StateError::DeserializationFailed(_))));
    }

    #[test]
    fn test_state_error_display() {
        assert_eq!(
            StateError::NoSessionId.to_string(),
            "no session ID set on context"
        );
        assert!(
            StateError::SerializationFailed("test".into())
                .to_string()
                .contains("serialization failed")
        );
        assert!(
            StateError::DeserializationFailed("test".into())
                .to_string()
                .contains("deserialization failed")
        );
    }

    #[tokio::test]
    async fn test_logging_without_server_to_client() {
        let ctx = RequestContext::new().with_session_id("logging-test");

        assert!(ctx.debug("debug message").await.is_ok());
        assert!(ctx.info("info message").await.is_ok());
        assert!(ctx.warning("warning message").await.is_ok());
        assert!(ctx.error("error message").await.is_ok());
        assert!(ctx.log(LogLevel::Notice, "notice", None).await.is_ok());
    }

    #[tokio::test]
    async fn test_progress_without_server_to_client() {
        let ctx = RequestContext::new().with_session_id("progress-test");

        assert!(ctx.report_progress(50, 100, Some("halfway")).await.is_ok());
        assert!(ctx.report_progress(100, 100, None).await.is_ok());
        assert!(
            ctx.report_progress_with_token("custom-token", 25, Some(100), Some("processing"))
                .await
                .is_ok()
        );
    }
}