turbomcp_protocol/context/
rich.rs1use std::collections::HashMap;
51use std::sync::Arc;
52
53use parking_lot::RwLock;
54use serde::{Serialize, de::DeserializeOwned};
55use serde_json::Value;
56use turbomcp_core::MaybeSend;
57
58use crate::McpError;
59use crate::types::LogLevel;
60
61use super::request::RequestContext;
62
63type SessionStateMap = dashmap::DashMap<String, Arc<RwLock<HashMap<String, Value>>>>;
65
66static SESSION_STATE: std::sync::LazyLock<SessionStateMap> =
91 std::sync::LazyLock::new(SessionStateMap::new);
92
/// Scope guard that clears a session's stored state when it is dropped.
///
/// Dropping the guard calls `cleanup_session_state` for its session ID, removing every
/// value stored for that session.
///
/// NOTE: guards are not reference counted. If several guards exist for the same session
/// ID, the first one dropped removes the state for all of them.
#[derive(Debug)]
#[must_use = "dropping a SessionStateGuard immediately clears the session's state"]
pub struct SessionStateGuard {
    session_id: String,
}

impl SessionStateGuard {
    /// Creates a guard that will clean up the state of `session_id` when dropped.
    pub fn new(session_id: impl Into<String>) -> Self {
        Self {
            session_id: session_id.into(),
        }
    }

    /// Returns the session ID whose state this guard cleans up.
    #[must_use]
    pub fn session_id(&self) -> &str {
        &self.session_id
    }
}
133
134impl Drop for SessionStateGuard {
135 fn drop(&mut self) {
136 cleanup_session_state(&self.session_id);
137 }
138}
139
/// Failure reasons reported by the fallible session-state accessors
/// (`try_get_state` / `try_set_state`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateError {
    /// The request context has no session ID, so there is no state to address.
    NoSessionId,
    /// The value could not be converted to JSON; carries the serializer's message.
    SerializationFailed(String),
    /// The stored JSON could not be converted to the requested type; carries the
    /// deserializer's message.
    DeserializationFailed(String),
}

impl std::fmt::Display for StateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::NoSessionId => return f.write_str("no session ID set on context"),
            Self::SerializationFailed(detail) => ("serialization failed", detail),
            Self::DeserializationFailed(detail) => ("deserialization failed", detail),
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for StateError {}
162
163pub trait RichContextExt {
168 fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T>;
174
175 fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError>;
180
181 fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool;
185
186 fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError>;
188
189 fn remove_state(&self, key: &str) -> bool;
191
192 fn clear_state(&self);
194
195 fn has_state(&self, key: &str) -> bool;
197
198 fn debug(
205 &self,
206 message: impl Into<String> + MaybeSend,
207 ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
208
209 fn info(
214 &self,
215 message: impl Into<String> + MaybeSend,
216 ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
217
218 fn warning(
223 &self,
224 message: impl Into<String> + MaybeSend,
225 ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
226
227 fn error(
232 &self,
233 message: impl Into<String> + MaybeSend,
234 ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
235
236 fn log(
244 &self,
245 level: LogLevel,
246 message: impl Into<String> + MaybeSend,
247 logger: Option<String>,
248 ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
249
250 fn report_progress(
271 &self,
272 current: f64,
273 total: f64,
274 message: Option<&str>,
275 ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
276
277 fn report_progress_with_token(
282 &self,
283 token: impl Into<crate::types::ProgressToken> + MaybeSend,
284 current: f64,
285 total: Option<f64>,
286 message: Option<&str>,
287 ) -> impl std::future::Future<Output = Result<(), McpError>> + MaybeSend;
288}
289
290impl RichContextExt for RequestContext {
291 fn get_state<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
292 self.try_get_state(key).ok().flatten()
293 }
294
295 fn try_get_state<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StateError> {
296 let session_id = self.session_id.as_ref().ok_or(StateError::NoSessionId)?;
297
298 let Some(state) = SESSION_STATE.get(session_id) else {
299 return Ok(None);
300 };
301
302 let state_read = state.read();
303 let Some(value) = state_read.get(key) else {
304 return Ok(None);
305 };
306
307 serde_json::from_value(value.clone())
308 .map(Some)
309 .map_err(|e| StateError::DeserializationFailed(e.to_string()))
310 }
311
312 fn set_state<T: Serialize>(&self, key: &str, value: &T) -> bool {
313 self.try_set_state(key, value).is_ok()
314 }
315
316 fn try_set_state<T: Serialize>(&self, key: &str, value: &T) -> Result<(), StateError> {
317 let session_id = self.session_id.as_ref().ok_or(StateError::NoSessionId)?;
318
319 let json_value = serde_json::to_value(value)
320 .map_err(|e| StateError::SerializationFailed(e.to_string()))?;
321
322 let state = SESSION_STATE
323 .entry(session_id.clone())
324 .or_insert_with(|| Arc::new(RwLock::new(HashMap::new())));
325
326 state.write().insert(key.to_string(), json_value);
327 Ok(())
328 }
329
330 fn remove_state(&self, key: &str) -> bool {
331 let Some(ref session_id) = self.session_id else {
332 return false;
333 };
334
335 if let Some(state) = SESSION_STATE.get(session_id) {
336 state.write().remove(key);
337 return true;
338 }
339 false
340 }
341
342 fn clear_state(&self) {
343 if let Some(ref session_id) = self.session_id
344 && let Some(state) = SESSION_STATE.get(session_id)
345 {
346 state.write().clear();
347 }
348 }
349
350 fn has_state(&self, key: &str) -> bool {
351 if let Some(ref session_id) = self.session_id
352 && let Some(state) = SESSION_STATE.get(session_id)
353 {
354 return state.read().contains_key(key);
355 }
356 false
357 }
358
359 async fn debug(&self, message: impl Into<String> + MaybeSend) -> Result<(), McpError> {
362 self.log(LogLevel::Debug, message, None).await
363 }
364
365 async fn info(&self, message: impl Into<String> + MaybeSend) -> Result<(), McpError> {
366 self.log(LogLevel::Info, message, None).await
367 }
368
369 async fn warning(&self, message: impl Into<String> + MaybeSend) -> Result<(), McpError> {
370 self.log(LogLevel::Warning, message, None).await
371 }
372
373 async fn error(&self, message: impl Into<String> + MaybeSend) -> Result<(), McpError> {
374 self.log(LogLevel::Error, message, None).await
375 }
376
377 async fn log(
378 &self,
379 level: LogLevel,
380 message: impl Into<String> + MaybeSend,
381 logger: Option<String>,
382 ) -> Result<(), McpError> {
383 if !self.has_session() {
385 return Ok(());
386 }
387
388 let mut params = serde_json::json!({
389 "level": level,
390 "data": message.into(),
391 });
392 if let Some(logger) = logger {
393 params["logger"] = serde_json::Value::String(logger);
394 }
395
396 self.notify_client("notifications/message", params).await
397 }
398
399 async fn report_progress(
402 &self,
403 current: f64,
404 total: f64,
405 message: Option<&str>,
406 ) -> Result<(), McpError> {
407 self.report_progress_with_token(self.request_id.as_str(), current, Some(total), message)
409 .await
410 }
411
412 async fn report_progress_with_token(
413 &self,
414 token: impl Into<crate::types::ProgressToken> + MaybeSend,
415 current: f64,
416 total: Option<f64>,
417 message: Option<&str>,
418 ) -> Result<(), McpError> {
419 if !self.has_session() {
420 return Ok(());
421 }
422
423 let mut params = serde_json::json!({
424 "progressToken": token.into(),
425 "progress": current,
426 });
427 if let Some(total) = total {
428 params["total"] = serde_json::json!(total);
429 }
430 if let Some(message) = message {
431 params["message"] = serde_json::Value::String(message.to_string());
432 }
433
434 self.notify_client("notifications/progress", params).await
435 }
436}
437
438pub fn cleanup_session_state(session_id: &str) {
453 SESSION_STATE.remove(session_id);
454}
455
456pub fn active_sessions_count() -> usize {
460 SESSION_STATE.len()
461}
462
463#[cfg(test)]
468pub fn clear_all_session_state() {
469 SESSION_STATE.clear();
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    // Each state test holds a `SessionStateGuard` so its session is removed from the
    // process-wide map even when an assertion panics; a trailing cleanup call would be
    // skipped in that case and leak state into other tests.

    #[test]
    fn test_get_set_state() {
        let _guard = SessionStateGuard::new("test-session-1");
        let ctx = RequestContext::new().with_session_id("test-session-1");

        assert!(ctx.set_state("counter", &42i32));
        assert!(ctx.set_state("name", &"Alice".to_string()));

        assert_eq!(ctx.get_state::<i32>("counter"), Some(42));
        assert_eq!(ctx.get_state::<String>("name"), Some("Alice".to_string()));
        assert_eq!(ctx.get_state::<i32>("missing"), None);

        assert!(ctx.has_state("counter"));
        assert!(!ctx.has_state("missing"));

        assert!(ctx.remove_state("counter"));
        assert_eq!(ctx.get_state::<i32>("counter"), None);
        assert!(!ctx.has_state("counter"));

        ctx.clear_state();
        assert_eq!(ctx.get_state::<String>("name"), None);
    }

    #[test]
    fn test_state_without_session() {
        let ctx = RequestContext::new();

        assert!(!ctx.set_state("key", &"value"));
        assert_eq!(ctx.get_state::<String>("key"), None);
        assert!(!ctx.has_state("key"));
        assert!(!ctx.remove_state("key"));
        // Must be a harmless no-op without a session.
        ctx.clear_state();

        assert_eq!(
            ctx.try_set_state("key", &"value"),
            Err(StateError::NoSessionId)
        );
        assert_eq!(
            ctx.try_get_state::<String>("key"),
            Err(StateError::NoSessionId)
        );
    }

    #[test]
    fn test_state_isolation() {
        let _guard1 = SessionStateGuard::new("session-iso-1");
        let _guard2 = SessionStateGuard::new("session-iso-2");
        let ctx1 = RequestContext::new().with_session_id("session-iso-1");
        let ctx2 = RequestContext::new().with_session_id("session-iso-2");

        assert!(ctx1.set_state("value", &1i32));
        assert!(ctx2.set_state("value", &2i32));

        assert_eq!(ctx1.get_state::<i32>("value"), Some(1));
        assert_eq!(ctx2.get_state::<i32>("value"), Some(2));
    }

    #[test]
    fn test_complex_types() {
        let _guard = SessionStateGuard::new("complex-session-1");
        let ctx = RequestContext::new().with_session_id("complex-session-1");

        #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
        struct MyData {
            count: i32,
            items: Vec<String>,
        }

        let data = MyData {
            count: 3,
            items: vec!["a".to_string(), "b".to_string(), "c".to_string()],
        };

        assert!(ctx.set_state("data", &data));
        let retrieved: Option<MyData> = ctx.get_state("data");
        assert_eq!(retrieved, Some(data));
    }

    #[test]
    fn test_session_state_guard() {
        let session_id = "guard-test-session";

        {
            let _guard = SessionStateGuard::new(session_id);
            let ctx = RequestContext::new().with_session_id(session_id);

            assert!(ctx.set_state("key", &"value"));
            assert_eq!(ctx.get_state::<String>("key"), Some("value".to_string()));

            assert!(active_sessions_count() > 0);
        }

        let ctx = RequestContext::new().with_session_id(session_id);
        assert_eq!(ctx.get_state::<String>("key"), None);
    }

    #[test]
    fn test_try_get_state_errors() {
        let _guard = SessionStateGuard::new("error-test-session");
        let ctx = RequestContext::new().with_session_id("error-test-session");
        assert!(ctx.set_state("number", &42i32));

        let result: Result<Option<String>, StateError> = ctx.try_get_state("number");
        assert!(matches!(result, Err(StateError::DeserializationFailed(_))));

        // The lenient accessor hides the mismatch, and the stored value is still intact.
        assert_eq!(ctx.get_state::<String>("number"), None);
        assert_eq!(ctx.try_get_state::<i32>("number"), Ok(Some(42)));
    }

    #[test]
    fn test_state_error_display() {
        assert_eq!(
            StateError::NoSessionId.to_string(),
            "no session ID set on context"
        );
        assert_eq!(
            StateError::SerializationFailed("test".into()).to_string(),
            "serialization failed: test"
        );
        assert_eq!(
            StateError::DeserializationFailed("test".into()).to_string(),
            "deserialization failed: test"
        );
    }

    #[tokio::test]
    async fn test_logging_without_server_to_client() {
        let ctx = RequestContext::new().with_session_id("logging-test");

        assert!(ctx.debug("debug message").await.is_ok());
        assert!(ctx.info("info message").await.is_ok());
        assert!(ctx.warning("warning message").await.is_ok());
        assert!(ctx.error("error message").await.is_ok());
        assert!(ctx.log(LogLevel::Notice, "notice", None).await.is_ok());
    }

    #[tokio::test]
    async fn test_progress_without_server_to_client() {
        let ctx = RequestContext::new().with_session_id("progress-test");

        assert!(
            ctx.report_progress(50.0, 100.0, Some("halfway"))
                .await
                .is_ok()
        );
        assert!(ctx.report_progress(100.0, 100.0, None).await.is_ok());
        assert!(
            ctx.report_progress_with_token("custom-token", 25.0, Some(100.0), Some("processing"))
                .await
                .is_ok()
        );
    }
}