1use serde_json::Value;
7use std::borrow::Cow;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::{debug, error, info, warn};
11
12use ultrafast_mcp_core::{
13 error::MCPResult,
14 protocol::jsonrpc::{JsonRpcMessage, JsonRpcRequest},
15 types::notifications::{LogLevel, LoggingMessageNotification, ProgressNotification},
16};
17
18#[derive(Debug, Clone)]
20pub struct CancellationManager {
21 cancelled_requests: Arc<tokio::sync::RwLock<std::collections::HashSet<String>>>,
22}
23
24impl CancellationManager {
25 pub fn new() -> Self {
26 Self {
27 cancelled_requests: Arc::new(
28 tokio::sync::RwLock::new(std::collections::HashSet::new()),
29 ),
30 }
31 }
32
33 pub async fn cancel_request(&self, request_id: &str) {
34 let mut requests = self.cancelled_requests.write().await;
35 requests.insert(request_id.to_string());
36 }
37
38 pub async fn is_cancelled(&self, request_id: &str) -> bool {
39 let requests = self.cancelled_requests.read().await;
40 requests.contains(request_id)
41 }
42
43 pub async fn clear_cancelled(&self, request_id: &str) {
44 let mut requests = self.cancelled_requests.write().await;
45 requests.remove(request_id);
46 }
47}
48
49impl Default for CancellationManager {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55#[derive(Debug, Clone)]
57pub struct LoggerConfig {
58 pub min_level: LogLevel,
60 pub send_notifications: bool,
62 pub structured_output: bool,
64 pub max_message_length: usize,
66 pub include_timestamps: bool,
68 pub include_logger_name: bool,
70 pub logger_name: Option<String>,
72}
73
74impl Default for LoggerConfig {
75 fn default() -> Self {
76 Self {
77 min_level: LogLevel::Info,
78 send_notifications: true,
79 structured_output: true,
80 max_message_length: 4096,
81 include_timestamps: true,
82 include_logger_name: true,
83 logger_name: None,
84 }
85 }
86}
87
88type NotificationSender = Arc<
90 dyn Fn(
91 JsonRpcMessage,
92 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MCPResult<()>> + Send>>
93 + Send
94 + Sync,
95>;
96
97#[derive(Clone)]
102pub struct Context {
103 session_id: Option<String>,
104 request_id: Option<String>,
105 metadata: HashMap<String, serde_json::Value>,
106 logger_config: LoggerConfig,
107 notification_sender: Option<NotificationSender>,
108 cancellation_manager: Option<Arc<CancellationManager>>,
109}
110
111impl std::fmt::Debug for Context {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 f.debug_struct("Context")
114 .field("session_id", &self.session_id)
115 .field("request_id", &self.request_id)
116 .field("metadata", &self.metadata)
117 .field("logger_config", &self.logger_config)
118 .field("notification_sender", &self.notification_sender.is_some())
119 .finish()
120 }
121}
122
123impl Context {
124 pub fn new() -> Self {
126 Self {
127 session_id: None,
128 request_id: None,
129 metadata: HashMap::new(),
130 logger_config: LoggerConfig::default(),
131 notification_sender: None,
132 cancellation_manager: None,
133 }
134 }
135
136 pub fn with_session_id(mut self, session_id: String) -> Self {
138 self.session_id = Some(session_id);
139 self
140 }
141
142 pub fn with_request_id(mut self, request_id: String) -> Self {
144 self.request_id = Some(request_id);
145 self
146 }
147
148 pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
150 self.metadata.insert(key, value);
151 self
152 }
153
154 pub fn with_logger_config(mut self, config: LoggerConfig) -> Self {
156 self.logger_config = config;
157 self
158 }
159
160 pub fn with_notification_sender(mut self, sender: NotificationSender) -> Self {
162 self.notification_sender = Some(sender);
163 self
164 }
165
166 pub fn with_cancellation_manager(mut self, manager: Arc<CancellationManager>) -> Self {
168 self.cancellation_manager = Some(manager);
169 self
170 }
171
172 pub fn session_id(&self) -> Option<&str> {
174 self.session_id.as_deref()
175 }
176
177 pub fn request_id(&self) -> Option<&str> {
179 self.request_id.as_deref()
180 }
181
182 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
184 self.metadata.get(key)
185 }
186
187 pub fn set_log_level(&mut self, level: LogLevel) {
189 self.logger_config.min_level = level;
190 }
191
192 pub fn get_log_level(&self) -> &LogLevel {
194 &self.logger_config.min_level
195 }
196
197 fn should_log(&self, level: &LogLevel) -> bool {
199 let level_priority = log_level_priority(level);
200 let min_priority = log_level_priority(&self.logger_config.min_level);
201 level_priority >= min_priority
202 }
203
204 pub async fn progress(
211 &self,
212 message: &str,
213 progress: f64,
214 total: Option<f64>,
215 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
216 if let Some(total) = total {
218 info!(
219 "Progress: {} - {:.2}/{:.2} ({:.1}%)",
220 message,
221 progress,
222 total,
223 (progress / total) * 100.0
224 );
225 } else {
226 info!("Progress: {} - {:.2}", message, progress);
227 }
228
229 if let Some(sender) = &self.notification_sender {
231 let progress_token = self
232 .request_id()
233 .map(|id| serde_json::Value::String(id.to_string()))
234 .unwrap_or(serde_json::Value::Null);
235
236 let mut notification = ProgressNotification::new(progress_token, progress)
237 .with_message(message.to_string());
238
239 if let Some(total) = total {
240 notification = notification.with_total(total);
241 }
242
243 let notification_request = JsonRpcRequest {
244 jsonrpc: Cow::Borrowed("2.0"),
245 id: None, method: "notifications/progress".to_string(),
247 params: Some(serde_json::to_value(notification)?),
248 meta: std::collections::HashMap::new(),
249 };
250
251 sender(JsonRpcMessage::Notification(notification_request)).await?;
252 }
253
254 Ok(())
255 }
256
257 pub async fn log_debug(
259 &self,
260 message: &str,
261 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
262 self.log_with_level(LogLevel::Debug, message, None).await
263 }
264
265 pub async fn log_debug_structured(
267 &self,
268 message: &str,
269 data: Value,
270 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
271 self.log_with_level(LogLevel::Debug, message, Some(data))
272 .await
273 }
274
275 pub async fn log_info(
277 &self,
278 message: &str,
279 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
280 self.log_with_level(LogLevel::Info, message, None).await
281 }
282
283 pub async fn log_info_structured(
285 &self,
286 message: &str,
287 data: Value,
288 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
289 self.log_with_level(LogLevel::Info, message, Some(data))
290 .await
291 }
292
293 pub async fn log_notice(
295 &self,
296 message: &str,
297 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
298 self.log_with_level(LogLevel::Notice, message, None).await
299 }
300
301 pub async fn log_notice_structured(
303 &self,
304 message: &str,
305 data: Value,
306 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
307 self.log_with_level(LogLevel::Notice, message, Some(data))
308 .await
309 }
310
311 pub async fn log_warn(
313 &self,
314 message: &str,
315 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
316 self.log_with_level(LogLevel::Warning, message, None).await
317 }
318
319 pub async fn log_warn_structured(
321 &self,
322 message: &str,
323 data: Value,
324 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
325 self.log_with_level(LogLevel::Warning, message, Some(data))
326 .await
327 }
328
329 pub async fn log_error(
331 &self,
332 message: &str,
333 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
334 self.log_with_level(LogLevel::Error, message, None).await
335 }
336
337 pub async fn log_error_structured(
339 &self,
340 message: &str,
341 data: Value,
342 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
343 self.log_with_level(LogLevel::Error, message, Some(data))
344 .await
345 }
346
347 pub async fn log_critical(
349 &self,
350 message: &str,
351 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
352 self.log_with_level(LogLevel::Critical, message, None).await
353 }
354
355 pub async fn log_critical_structured(
357 &self,
358 message: &str,
359 data: Value,
360 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
361 self.log_with_level(LogLevel::Critical, message, Some(data))
362 .await
363 }
364
365 pub async fn log_alert(
367 &self,
368 message: &str,
369 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
370 self.log_with_level(LogLevel::Alert, message, None).await
371 }
372
373 pub async fn log_alert_structured(
375 &self,
376 message: &str,
377 data: Value,
378 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
379 self.log_with_level(LogLevel::Alert, message, Some(data))
380 .await
381 }
382
383 pub async fn log_emergency(
385 &self,
386 message: &str,
387 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
388 self.log_with_level(LogLevel::Emergency, message, None)
389 .await
390 }
391
392 pub async fn log_emergency_structured(
394 &self,
395 message: &str,
396 data: Value,
397 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
398 self.log_with_level(LogLevel::Emergency, message, Some(data))
399 .await
400 }
401
402 async fn log_with_level(
404 &self,
405 level: LogLevel,
406 message: &str,
407 structured_data: Option<Value>,
408 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
409 if !self.should_log(&level) {
411 return Ok(());
412 }
413
414 let truncated_message = if message.len() > self.logger_config.max_message_length {
416 let mut truncated = message[..self.logger_config.max_message_length - 3].to_string();
417 truncated.push_str("...");
418 truncated
419 } else {
420 message.to_string()
421 };
422
423 let log_data = if self.logger_config.structured_output {
425 let mut data_obj = serde_json::Map::new();
426
427 data_obj.insert(
429 "message".to_string(),
430 Value::String(truncated_message.clone()),
431 );
432
433 if let Some(request_id) = &self.request_id {
435 data_obj.insert("request_id".to_string(), Value::String(request_id.clone()));
436 }
437
438 if let Some(session_id) = &self.session_id {
439 data_obj.insert("session_id".to_string(), Value::String(session_id.clone()));
440 }
441
442 if self.logger_config.include_timestamps {
444 let timestamp = chrono::Utc::now().to_rfc3339();
445 data_obj.insert("timestamp".to_string(), Value::String(timestamp));
446 }
447
448 if self.logger_config.include_logger_name {
450 let logger_name = self
451 .logger_config
452 .logger_name
453 .as_deref()
454 .unwrap_or("ultrafast-mcp-server");
455 data_obj.insert("logger".to_string(), Value::String(logger_name.to_string()));
456 }
457
458 data_obj.insert(
460 "level".to_string(),
461 Value::String(format!("{level:?}").to_lowercase()),
462 );
463
464 if let Some(data) = structured_data {
466 data_obj.insert("data".to_string(), data);
467 }
468
469 if !self.metadata.is_empty() {
471 data_obj.insert(
472 "metadata".to_string(),
473 Value::Object(
474 self.metadata
475 .iter()
476 .map(|(k, v)| (k.clone(), v.clone()))
477 .collect(),
478 ),
479 );
480 }
481
482 Value::Object(data_obj)
483 } else {
484 Value::String(truncated_message.clone())
486 };
487
488 let request_context = self.request_id.as_deref().unwrap_or("unknown");
490 match level {
491 LogLevel::Debug => debug!("[{}] {}", request_context, truncated_message),
492 LogLevel::Info => info!("[{}] {}", request_context, truncated_message),
493 LogLevel::Notice => info!("[{}] NOTICE: {}", request_context, truncated_message),
494 LogLevel::Warning => warn!("[{}] {}", request_context, truncated_message),
495 LogLevel::Error => error!("[{}] {}", request_context, truncated_message),
496 LogLevel::Critical => error!("[{}] CRITICAL: {}", request_context, truncated_message),
497 LogLevel::Alert => error!("[{}] ALERT: {}", request_context, truncated_message),
498 LogLevel::Emergency => error!("[{}] EMERGENCY: {}", request_context, truncated_message),
499 }
500
501 if self.logger_config.send_notifications {
503 if let Some(sender) = &self.notification_sender {
504 let logger_name = self
505 .logger_config
506 .logger_name
507 .as_deref()
508 .unwrap_or("ultrafast-mcp-server");
509
510 let notification = LoggingMessageNotification::new(level, log_data)
511 .with_logger(logger_name.to_string());
512
513 let notification_request = JsonRpcRequest {
514 jsonrpc: Cow::Borrowed("2.0"),
515 id: None, method: "notifications/message".to_string(),
517 params: Some(serde_json::to_value(notification)?),
518 meta: std::collections::HashMap::new(),
519 };
520
521 if let Err(e) = sender(JsonRpcMessage::Notification(notification_request)).await {
523 error!("Failed to send logging notification: {}", e);
525 }
526 }
527 }
528
529 Ok(())
530 }
531
532 pub async fn is_cancelled(&self) -> bool {
534 if let Some(cancellation_manager) = &self.cancellation_manager {
535 if let Some(request_id) = &self.request_id {
536 cancellation_manager.is_cancelled(request_id).await
537 } else {
538 false
539 }
540 } else {
541 false
542 }
543 }
544}
545
546impl Default for Context {
547 fn default() -> Self {
548 Self::new()
549 }
550}
551
552fn log_level_priority(level: &LogLevel) -> u8 {
554 match level {
555 LogLevel::Debug => 0,
556 LogLevel::Info => 1,
557 LogLevel::Notice => 2,
558 LogLevel::Warning => 3,
559 LogLevel::Error => 4,
560 LogLevel::Critical => 5,
561 LogLevel::Alert => 6,
562 LogLevel::Emergency => 7,
563 }
564}
565
566pub struct ContextLogger {
568 config: LoggerConfig,
569}
570
571impl ContextLogger {
572 pub fn new() -> Self {
573 Self {
574 config: LoggerConfig::default(),
575 }
576 }
577
578 pub fn with_min_level(mut self, level: LogLevel) -> Self {
579 self.config.min_level = level;
580 self
581 }
582
583 pub fn with_notifications(mut self, send_notifications: bool) -> Self {
584 self.config.send_notifications = send_notifications;
585 self
586 }
587
588 pub fn with_structured_output(mut self, structured: bool) -> Self {
589 self.config.structured_output = structured;
590 self
591 }
592
593 pub fn with_max_message_length(mut self, length: usize) -> Self {
594 self.config.max_message_length = length;
595 self
596 }
597
598 pub fn with_timestamps(mut self, include: bool) -> Self {
599 self.config.include_timestamps = include;
600 self
601 }
602
603 pub fn with_logger_name(mut self, name: String) -> Self {
604 self.config.logger_name = Some(name);
605 self
606 }
607
608 pub fn build(self) -> LoggerConfig {
609 self.config
610 }
611}
612
613impl Default for ContextLogger {
614 fn default() -> Self {
615 Self::new()
616 }
617}
618
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder-style setters must be reflected by the matching getters.
    #[tokio::test]
    async fn test_context_creation() {
        let context = Context::new()
            .with_session_id(String::from("session-123"))
            .with_request_id(String::from("request-456"))
            .with_metadata(String::from("key"), serde_json::json!("value"));

        let expected_value = serde_json::json!("value");

        assert_eq!(context.session_id(), Some("session-123"));
        assert_eq!(context.request_id(), Some("request-456"));
        assert_eq!(context.get_metadata("key"), Some(&expected_value));
    }

    /// Logging without a notification sender must succeed at several levels.
    #[tokio::test]
    async fn test_context_logging() {
        let context = Context::new().with_request_id(String::from("test-request"));

        context.log_info("Test info message").await.unwrap();
        context.log_warn("Test warning message").await.unwrap();
        context.log_error("Test error message").await.unwrap();
    }

    /// Progress reporting must succeed both with and without a known total.
    #[tokio::test]
    async fn test_context_progress() {
        let context = Context::new();

        // Determinate progress: the same total at three stages.
        let stages = [
            ("Starting operation", 0.0),
            ("Halfway done", 0.5),
            ("Completed", 1.0),
        ];
        for (label, done) in stages {
            context.progress(label, done, Some(1.0)).await.unwrap();
        }

        // Indeterminate progress: no total supplied.
        context
            .progress("Indeterminate progress", 0.3, None)
            .await
            .unwrap();
    }
}